In [None]:
# ============================================================
# Quantization Error Geometry: A Visual Exploration
# ============================================================
#
# Traces how quantization error regions evolve through neural
# network layers, building intuition from simple to complex cases.
#
# Structure:
#   1. Framework (this cell)
#   2. 2D Experiments: 4 experiments in 2D for clarity
#   3. 3D Extensions: unique 3D visualizations (boxes,
#      parallelepipeds, SVD shape analysis, bounding box efficiency)
#   4. Manifold Analysis: error patterns across input manifolds

import os
import warnings
from dataclasses import dataclass, field
from itertools import product
from typing import List, Dict

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from scipy.spatial import ConvexHull, QhullError
warnings.filterwarnings('ignore')


# ============================================================
# Configuration
# ============================================================

# Bit width used to derive the quantization step below.
BITS = 8
# Quantization grid step: 1 / 2**(BITS - 1), i.e. 1/128 for BITS = 8.
# This is the default `delta` of quantize() and the step handed to the
# per-layer error functions.
DELTA = 1.0 / (2 ** (BITS - 1))
# Number of layers assumed by plot_summary() when it lays out its
# per-layer spectral-norm bars.
N_LAYERS = 4

# Hex colours shared by the plotting helpers, looked up by role.
COLORS = {
    'layer1': '#1f77b4',
    'layer2': '#ff7f0e',
    'layer3': '#2ca02c',
    'layer4': '#d62728',
    'cumulative': '#e377c2',
    'input': '#17becf',
    'error_region': '#ff6b6b',
    'reference': '#888888'
}


# ============================================================
# Data structures
# ============================================================

@dataclass
class LayerStats:
    """Statistics for a single layer.

    The per-field notes below describe how run_experiment() fills the
    record; the dataclass itself does no validation.
    """
    # Zero-based position of the layer in the network.
    layer_idx: int
    # Copy of the quantized weight matrix of this layer.
    weight_matrix: np.ndarray
    # Matrix 2-norm of the quantized weights (np.linalg.norm(W, ord=2)).
    spectral_norm: float
    # Determinant of the quantized weights.
    determinant: float
    # Largest / smallest singular value; np.inf when the smallest is 0.
    condition_number: float
    # Per-axis max |coordinate| of this layer's error vertices after they
    # have been mapped back to input space.
    error_half_widths: np.ndarray
    # Convex-hull area of those mapped-back error vertices.
    error_volume: float


@dataclass
class ExperimentStats:
    """Aggregated result of one experiment: per-layer records plus the
    cumulative error region and the figures derived from it."""
    name: str
    input_point: np.ndarray
    layer_stats: List[LayerStats]
    cumulative_error_vertices: np.ndarray
    cumulative_error_volume: float
    bounding_box: np.ndarray
    relative_error: np.ndarray

    def summary(self):
        """Return the headline numbers as a dict of plain Python values.

        Arrays are converted with ``tolist()``; the per-layer spectral
        norms are collected into a list in layer order.
        """
        digest = dict(
            name=self.name,
            input=self.input_point.tolist(),
            final_volume=self.cumulative_error_volume,
            bbox=self.bounding_box.tolist(),
            relative_error=self.relative_error.tolist(),
        )
        # Added last so the key order matches the documented layout.
        digest['spectral_norms'] = [
            layer.spectral_norm for layer in self.layer_stats
        ]
        return digest


@dataclass
class AllExperimentStats:
    """Registry of experiment results, keyed by experiment name."""
    experiments: Dict[str, ExperimentStats] = field(default_factory=dict)

    def add(self, stats: ExperimentStats):
        """Store ``stats`` under ``stats.name``, replacing any earlier entry."""
        self.experiments.update({stats.name: stats})

    def print_summary(self):
        """Print a banner followed by the headline numbers of every entry."""
        rule = "=" * 70
        print("\n" + rule)
        print("SUMMARY OF ALL EXPERIMENTS")
        print(rule)
        for label in self.experiments:
            entry = self.experiments[label]
            print(f"\n{label}:")
            print(f"  Input: {entry.input_point}")
            print(f"  Final error volume: {entry.cumulative_error_volume:.6f}")
            print(f"  Bounding box: {entry.bounding_box}")
            print(f"  Relative error: {entry.relative_error}")


# Module-wide registry; run_experiment() adds each result to it.
ALL_STATS = AllExperimentStats()


# ============================================================
# Core functions (2D)
# ============================================================

def quantize(W, delta=DELTA):
    """Snap every entry of ``W`` to the nearest multiple of ``delta``."""
    grid_index = np.round(W / delta)
    return grid_index * delta


def get_box_vertices_2d(half_widths):
    """Return the four corners of an origin-centred, axis-aligned rectangle.

    ``half_widths`` supplies the half extent along x and y. Corners come
    back as a (4, 2) array in the order (-x, -y), (-x, +y), (+x, +y),
    (+x, -y), i.e. walking around the rectangle.
    """
    extent = np.array(half_widths)
    half_x, half_y = extent[0], extent[1]
    corner_signs = ((-1, -1), (-1, 1), (1, 1), (1, -1))
    return np.array([[sx * half_x, sy * half_y] for sx, sy in corner_signs])


def minkowski_sum_2d(V1, V2):
    """Minkowski sum of two 2D vertex sets.

    Every vertex of ``V1`` is added to every vertex of ``V2`` and the
    result is reduced to the vertices of its convex hull.

    Args:
        V1: Iterable of 2D points (rows).
        V2: Iterable of 2D points (rows).

    Returns:
        Array of hull vertices. When fewer than three sums exist, or the
        hull cannot be built (e.g. all sums are collinear), the full
        array of pairwise sums is returned unchanged.
    """
    sums = np.array([v1 + v2 for v1 in V1 for v2 in V2])
    if len(sums) < 3:
        return sums
    try:
        hull = ConvexHull(sums)
    except (QhullError, ValueError):
        # Degenerate or malformed point set: keep the raw sums as a
        # best-effort answer instead of hiding unrelated errors.
        return sums
    return sums[hull.vertices]


def compute_polygon_area(vertices):
    """Compute the area of the convex hull of ``vertices``.

    Args:
        vertices: Sequence of 2D points.

    Returns:
        The hull area, or 0.0 when there are fewer than three points or
        the hull cannot be built (e.g. collinear points).
    """
    if len(vertices) < 3:
        return 0.0
    try:
        # For 2D input, ConvexHull.volume is the enclosed area.
        return ConvexHull(vertices).volume
    except (QhullError, ValueError):
        # Degenerate or malformed input has no area to report.
        return 0.0


def transform_vertices(vertices, W):
    """Map each row vector ``v`` of ``vertices`` to ``W @ v``.

    Points are stored as rows, so the product is taken with ``W``
    transposed on the right.
    """
    W_transposed = W.T
    return np.matmul(vertices, W_transposed)


def draw_polygon(ax, vertices, color, alpha=0.3, edgecolor=None, linewidth=2, label=None):
    """Draw the convex hull of ``vertices`` as a filled, outlined polygon.

    Args:
        ax: Axes object to draw on.
        vertices: Array of 2D points (rows).
        color: Fill colour; also the outline colour unless ``edgecolor``
            is given.
        alpha: Fill opacity.
        edgecolor: Optional outline colour.
        linewidth: Outline width.
        label: Optional legend label.

    With fewer than three points, or when no hull can be built from
    them, the points are drawn as a scatter instead.
    """
    hull = None
    if len(vertices) >= 3:
        try:
            hull = ConvexHull(vertices)
        except (QhullError, ValueError):
            # Only hull construction failures fall back to a scatter;
            # genuine drawing errors are allowed to surface.
            hull = None

    if hull is None:
        ax.scatter(vertices[:, 0], vertices[:, 1], c=color, s=50, label=label)
        return

    outline = vertices[hull.vertices]
    # Repeat the first vertex so the outline closes on itself.
    outline = np.vstack([outline, outline[0]])
    ax.fill(outline[:, 0], outline[:, 1], color=color, alpha=alpha, label=label)
    ax.plot(outline[:, 0], outline[:, 1], color=edgecolor or color, linewidth=linewidth)


def set_fixed_scale(ax, scale, center=(0, 0)):
    """Give ``ax`` a square window of half-size ``scale`` around ``center``.

    Also forces an equal aspect ratio, draws thin black lines along
    x = 0 and y = 0, and switches on a faint grid.
    """
    cx, cy = center[0], center[1]
    ax.set_xlim(cx - scale, cx + scale)
    ax.set_ylim(cy - scale, cy + scale)
    ax.set_aspect('equal')
    # Reference lines through the origin (not through `center`).
    for draw_reference_line in (ax.axhline, ax.axvline):
        draw_reference_line(0, color='k', linewidth=0.5)
    ax.grid(True, alpha=0.3)


# ============================================================
# Core functions (3D)
# ============================================================

def get_hypercube_vertices(half_width, dims=3):
    """Return the 2**dims corners of an origin-centred hypercube.

    Corners are produced as every ±1 sign pattern (last axis varying
    fastest) scaled by ``half_width``.
    """
    sign_patterns = list(product([-1, 1], repeat=dims))
    corners = np.array(sign_patterns)
    return corners * half_width


def draw_box_3d(ax, center, half_widths, color, alpha=0.3, label=None):
    """Render an axis-aligned 3D box as six translucent faces.

    The box is centred at ``center`` with per-axis ``half_widths``. If
    ``label`` is given it is written slightly above the top face.
    """
    extents = np.array(half_widths)
    unit_corners = np.array(list(product([-1, 1], repeat=3)))
    corners = unit_corners * extents + center

    # Corner k has sign bits (x, y, z) = binary digits of k, because
    # product() varies the last axis fastest. Each tuple walks one face.
    face_corner_ids = (
        (0, 1, 3, 2),  # x low
        (4, 5, 7, 6),  # x high
        (0, 1, 5, 4),  # y low
        (2, 3, 7, 6),  # y high
        (0, 2, 6, 4),  # z low
        (1, 3, 7, 5),  # z high
    )
    faces = [[corners[k] for k in ids] for ids in face_corner_ids]

    box = Poly3DCollection(
        faces, alpha=alpha, facecolor=color, edgecolor='black', linewidth=0.5
    )
    ax.add_collection3d(box)

    if label:
        label_height = center[2] + extents[2] * 1.2
        ax.text(center[0], center[1], label_height, label, fontsize=10)


def draw_vertices_and_hull_3d(ax, vertices, color, alpha=0.3):
    """Draw vertices and their convex hull in 3D.

    Args:
        ax: 3D axes to draw on.
        vertices: Array of 3D points (rows).
        color: Colour for both the points and the hull faces.
        alpha: Opacity of the hull faces.

    The points are always scattered. Hull triangles are added only when
    at least four points are given and a hull can be built from them.
    """
    ax.scatter(vertices[:, 0], vertices[:, 1], vertices[:, 2],
               c=color, s=20, alpha=0.8)
    if len(vertices) < 4:
        return
    try:
        hull = ConvexHull(vertices)
    except (QhullError, ValueError):
        # Flat or otherwise degenerate point set: the scatter is all we
        # can show. Drawing errors are no longer swallowed.
        return
    for simplex in hull.simplices:
        triangle = vertices[simplex]
        ax.add_collection3d(Poly3DCollection(
            [triangle], alpha=alpha, facecolor=color,
            edgecolor='black', linewidth=0.5
        ))


def draw_wireframe_box(ax, half_widths, color='gray', alpha=0.5):
    """Plot the twelve edges of an origin-centred, axis-aligned box.

    ``half_widths`` gives the half extent along x, y and z.
    """
    hx, hy, hz = half_widths[0], half_widths[1], half_widths[2]
    segments = []
    for a in (1, -1):
        for b in (1, -1):
            # Edge parallel to z, at x = a*hx, y = b*hy.
            segments.append([[hx * a, hy * b, -hz], [hx * a, hy * b, hz]])
            # Edge parallel to y, at x = a*hx, z = b*hz.
            segments.append([[hx * a, -hy, hz * b], [hx * a, hy, hz * b]])
            # Edge parallel to x, at y = a*hy, z = b*hz.
            segments.append([[-hx, hy * a, hz * b], [hx, hy * a, hz * b]])
    for start, end in segments:
        xs, ys, zs = zip(start, end)
        ax.plot3D(xs, ys, zs, color=color, alpha=alpha, linewidth=1)


def minkowski_sum_3d(V1, V2):
    """Minkowski sum of two 3D vertex sets.

    Every vertex of ``V1`` is added to every vertex of ``V2`` and the
    result is reduced to the vertices of its convex hull.

    Args:
        V1: Iterable of 3D points (rows).
        V2: Iterable of 3D points (rows).

    Returns:
        Array of hull vertices. With four or fewer sums, or when the
        hull cannot be built (e.g. all sums are coplanar), the full
        array of pairwise sums is returned unchanged.
    """
    sums = np.array([v1 + v2 for v1 in V1 for v2 in V2])
    if len(sums) <= 4:
        return sums
    try:
        hull = ConvexHull(sums)
    except (QhullError, ValueError):
        # Degenerate or malformed point set: keep the raw sums.
        return sums
    return sums[hull.vertices]


# ============================================================
# Generic experiment runner (2D)
# ============================================================

def run_experiment(name, x_input, weight_matrices, compute_error_fn):
    """
    Run a 2D experiment and collect statistics.

    The weights are quantized, the input is pushed through the quantized
    layers, and at each layer the local error region returned by
    ``compute_error_fn`` is mapped back to input space (via the inverse
    of the cumulative quantized transform) and accumulated with a
    Minkowski sum. The result is also registered in ``ALL_STATS``.

    Args:
        name: Experiment name (also the key used in ``ALL_STATS``)
        x_input: Input point (2D)
        weight_matrices: List of weight matrices (true, pre-quantization)
        compute_error_fn: Callable ``(val, W, delta) -> vertices`` giving
            the error vertices at a layer, where ``val`` is the layer's
            input and ``W`` its quantized weight matrix

    Returns:
        Tuple ``(stats, quant_weights)``: the ExperimentStats object and
        the list of quantized weight matrices.

    Raises:
        ValueError: If ``weight_matrices`` is empty.
    """
    if len(weight_matrices) == 0:
        # Without layers there is no error region to summarise.
        raise ValueError("run_experiment requires at least one weight matrix")

    quant_weights = [quantize(W) for W in weight_matrices]

    layer_stats = []
    val = x_input.copy()
    cumulative_W = np.eye(2)
    cumulative_error_vertices = None

    for i, W in enumerate(quant_weights):
        spectral_norm = np.linalg.norm(W, ord=2)
        det = np.linalg.det(W)
        svd = np.linalg.svd(W, compute_uv=False)
        cond = svd.max() / svd.min() if svd.min() > 0 else np.inf

        local_error_vertices = compute_error_fn(val, W, DELTA)

        cumulative_W_after = W @ cumulative_W
        try:
            inv_W = np.linalg.inv(cumulative_W_after)
        except np.linalg.LinAlgError:
            # Singular cumulative transform: the region cannot be pulled
            # back to input space, so keep it in layer coordinates.
            error_vertices_input = local_error_vertices
        else:
            error_vertices_input = transform_vertices(local_error_vertices, inv_W)

        if cumulative_error_vertices is None:
            cumulative_error_vertices = error_vertices_input
        else:
            cumulative_error_vertices = minkowski_sum_2d(
                cumulative_error_vertices, error_vertices_input
            )

        # Per-axis extent of this layer's contribution in input space.
        hw = np.abs(error_vertices_input).max(axis=0)

        layer_stats.append(LayerStats(
            layer_idx=i,
            weight_matrix=W.copy(),
            spectral_norm=spectral_norm,
            determinant=det,
            condition_number=cond,
            error_half_widths=hw,
            error_volume=compute_polygon_area(error_vertices_input)
        ))

        val = W @ val
        cumulative_W = cumulative_W_after

    bbox_min = cumulative_error_vertices.min(axis=0)
    bbox_max = cumulative_error_vertices.max(axis=0)
    # Half the bounding-box width relative to |input|; the small constant
    # avoids division by zero for zero-valued input components.
    rel_error = (bbox_max - bbox_min) / (2 * np.abs(x_input) + 1e-10)

    stats = ExperimentStats(
        name=name,
        input_point=x_input.copy(),
        layer_stats=layer_stats,
        cumulative_error_vertices=cumulative_error_vertices,
        cumulative_error_volume=compute_polygon_area(cumulative_error_vertices),
        bounding_box=np.array([bbox_min, bbox_max]),
        relative_error=rel_error
    )

    ALL_STATS.add(stats)
    return stats, quant_weights


# Confirm the framework cell ran and echo the active quantization settings.
print(f"Framework loaded. BITS={BITS}, DELTA={DELTA:.6f}")

In [None]:
# ============================================================
# 2D EXPERIMENTS
# ============================================================
#
# Exp 1: Uniform diagonal weights (baseline)
# Exp 2: Non-uniform diagonal weights (per-channel variation)
# Exp 3: Full matrices (channel mixing, rotation/shear)
# Exp 4: Multiple input points (error manifold)


# --- Error functions ---

def diagonal_error_fn(val, W, delta):
    """Error region for diagonal weights: an axis-aligned box.

    Each half-width is half a quantization step times the magnitude of
    the matching component of ``val``. ``W`` is accepted for signature
    compatibility but not used.
    """
    half_step = delta / 2
    half_widths = half_step * np.abs(val)
    return get_box_vertices_2d(half_widths)


def full_matrix_error_fn(val, W, delta):
    """
    Error region for full (non-diagonal) 2x2 weight matrices.

    For output dimension i the error is sum_j W_err[i,j] * val[j], with
    each W_err[i,j] ranging independently over [-delta/2, delta/2]. The
    worst case per output dimension is therefore delta/2 times the L1
    norm of ``val``, giving a square box. ``W`` is accepted for signature
    compatibility but not used.
    """
    half_step = delta / 2
    l1_norm = np.abs(val).sum()
    half_widths = half_step * l1_norm * np.ones(2)
    return get_box_vertices_2d(half_widths)


# --- Experiment definitions ---

def run_experiment_1(x_input):
    """Baseline experiment: every layer is a scalar multiple of the identity.

    Returns the ``(stats, quantized_weights)`` pair from run_experiment().
    """
    layer_gains = [0.9, 1.1, 0.85, 1.05]
    weights = [np.eye(2) * gain for gain in layer_gains]
    stats, qw = run_experiment(
        "Exp1: Uniform Diagonal", x_input, weights, diagonal_error_fn
    )
    volume, rel = stats.cumulative_error_volume, stats.relative_error
    print(f"Exp 1: volume={volume:.6f}, rel_error={rel}")
    return stats, qw


def run_experiment_2(x_input):
    """Diagonal layers whose two channels are scaled differently.

    Returns the ``(stats, quantized_weights)`` pair from run_experiment().
    """
    per_layer_diagonals = [[0.8, 1.2], [1.1, 0.7], [0.9, 1.1], [1.2, 0.8]]
    weights = [np.diag(entries) for entries in per_layer_diagonals]
    stats, qw = run_experiment(
        "Exp2: Non-Uniform Diagonal", x_input, weights, diagonal_error_fn
    )
    volume, rel = stats.cumulative_error_volume, stats.relative_error
    print(f"Exp 2: volume={volume:.6f}, rel_error={rel}")
    return stats, qw


def run_experiment_3(x_input):
    """Layers with off-diagonal entries, so the two channels mix.

    Returns the ``(stats, quantized_weights)`` pair from run_experiment().
    """
    layer_rows = (
        [[0.9, 0.2], [0.1, 1.0]],
        [[0.95, -0.15], [0.2, 0.85]],
        [[1.0, 0.1], [-0.1, 0.9]],
        [[0.85, 0.15], [0.1, 1.05]],
    )
    weights = [np.array(rows) for rows in layer_rows]
    stats, qw = run_experiment(
        "Exp3: Full Matrices", x_input, weights, full_matrix_error_fn
    )
    volume, rel = stats.cumulative_error_volume, stats.relative_error
    print(f"Exp 3: volume={volume:.6f}, rel_error={rel}")
    return stats, qw


def run_experiment_4(base_weights, n_points=32, radius=20):
    """Trace the cumulative error region for input points on a circle.

    The weights are quantized once; each sampled point is then pushed
    through the quantized layers, accumulating the full-matrix error box
    of every layer in input space with a Minkowski sum.

    Args:
        base_weights: List of 2x2 weight matrices (pre-quantization).
        n_points: Number of evenly spaced sample points on the circle.
        radius: Radius of the origin-centred input circle.

    Returns:
        Tuple ``(results, circle_points, quant_weights)`` where
        ``results`` holds one dict per point with keys ``'input'``,
        ``'error_vertices'``, ``'error_magnitude'`` (largest vertex
        norm) and ``'error_volume'`` (hull area).
    """
    quant_weights = [quantize(W) for W in base_weights]
    theta = np.linspace(0, 2 * np.pi, n_points, endpoint=False)
    circle_points = np.column_stack([radius * np.cos(theta), radius * np.sin(theta)])

    results = []
    for x in circle_points:
        val = x.copy()
        cumulative_W = np.eye(2)
        cumulative_error_vertices = None

        for W in quant_weights:
            local_error_vertices = full_matrix_error_fn(val, W, DELTA)
            cumulative_W_after = W @ cumulative_W
            try:
                inv_W = np.linalg.inv(cumulative_W_after)
            except np.linalg.LinAlgError:
                # Singular cumulative transform: keep the region in
                # layer coordinates rather than failing.
                error_vertices_input = local_error_vertices
            else:
                error_vertices_input = transform_vertices(local_error_vertices, inv_W)

            if cumulative_error_vertices is None:
                cumulative_error_vertices = error_vertices_input
            else:
                cumulative_error_vertices = minkowski_sum_2d(
                    cumulative_error_vertices, error_vertices_input
                )
            val = W @ val
            cumulative_W = cumulative_W_after

        error_magnitude = np.max(np.linalg.norm(cumulative_error_vertices, axis=1))
        results.append({
            'input': x.copy(),
            'error_vertices': cumulative_error_vertices.copy(),
            'error_magnitude': error_magnitude,
            'error_volume': compute_polygon_area(cumulative_error_vertices)
        })

    magnitudes = [r['error_magnitude'] for r in results]
    print(f"Exp 4: error range [{min(magnitudes):.4f}, {max(magnitudes):.4f}], "
          f"variation ratio {max(magnitudes)/min(magnitudes):.2f}x")
    return results, circle_points, quant_weights


# ============================================================
# Plotting functions (2D)
# ============================================================

def plot_experiment_1_2(stats1, stats2, scale):
    """Compare Exp 1 (uniform) and Exp 2 (non-uniform) in a 2x3 figure.

    Each row shows one experiment: its cumulative error region, the
    per-layer error half-widths, and the relative error per dimension.

    Args:
        stats1: ExperimentStats of experiment 1 (top row).
        stats2: ExperimentStats of experiment 2 (bottom row).
        scale: Half-size of the axis window for the error-region panels.

    The figure is saved to ``plots/exp1_2_comparison.png`` (the directory
    is created if needed) and then shown.
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    for row, stats, title in [(0, stats1, 'Exp 1: Uniform Diagonal'),
                               (1, stats2, 'Exp 2: Non-Uniform Diagonal')]:
        # Error region
        ax = axes[row, 0]
        draw_polygon(ax, stats.cumulative_error_vertices, COLORS['error_region'], alpha=0.4)
        set_fixed_scale(ax, scale)
        ax.set_title(f"{title}\nVolume: {stats.cumulative_error_volume:.6f}")
        ax.set_xlabel('Dim 0')
        ax.set_ylabel('Dim 1')

        # Per-layer contributions, one pair of bars per layer
        ax = axes[row, 1]
        layers = [ls.layer_idx + 1 for ls in stats.layer_stats]
        x_pos = np.arange(len(layers))
        hw0 = [ls.error_half_widths[0] for ls in stats.layer_stats]
        hw1 = [ls.error_half_widths[1] for ls in stats.layer_stats]
        ax.bar(x_pos - 0.2, hw0, 0.4, label='Dim 0', color=COLORS['layer1'])
        ax.bar(x_pos + 0.2, hw1, 0.4, label='Dim 1', color=COLORS['layer2'])
        ax.set_xticks(x_pos)
        ax.set_xticklabels([f'L{l}' for l in layers])
        ax.set_ylabel('Error half-width')
        ax.set_title(f'{title}: Per-layer error')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Relative error, shown in percent
        ax = axes[row, 2]
        ax.bar(['Dim 0', 'Dim 1'], stats.relative_error * 100,
               color=[COLORS['layer1'], COLORS['layer2']])
        ax.set_ylabel('Relative error (%)')
        subtitle = '(Equal for uniform)' if row == 0 else '(Different per channel!)'
        ax.set_title(f'{title}: Relative error\n{subtitle}')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs('plots', exist_ok=True)
    plt.savefig('plots/exp1_2_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()


def plot_experiment_3(stats3, stats1, scale):
    """Plot Experiment 3 results in a 1x3 figure.

    Panels: the full-matrix error region overlaid on the diagonal one,
    the singular values of the centred error vertices, and the region
    with its principal directions drawn as arrows.

    Args:
        stats3: ExperimentStats of experiment 3 (full matrices).
        stats1: ExperimentStats of experiment 1, used as the reference.
        scale: Half-size of the axis window for the region panels.

    The figure is saved to ``plots/exp3_full_matrices.png`` (the
    directory is created if needed) and then shown.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Error region comparison
    ax = axes[0]
    draw_polygon(ax, stats1.cumulative_error_vertices, COLORS['layer1'], alpha=0.3, label='Diagonal')
    draw_polygon(ax, stats3.cumulative_error_vertices, COLORS['error_region'], alpha=0.4, label='Full')
    set_fixed_scale(ax, scale)
    ax.set_title('Error region comparison')
    ax.legend()

    # SVD shape analysis of the vertices about their mean
    ax = axes[1]
    center = stats3.cumulative_error_vertices.mean(axis=0)
    centered = stats3.cumulative_error_vertices - center
    _, S, Vt = np.linalg.svd(centered, full_matrices=False)
    ax.bar(['PC1', 'PC2'], S, color=[COLORS['layer1'], COLORS['layer2']])
    ax.set_ylabel('Singular value')
    ax.set_title(f'Error region shape (SVD)\nCondition: {S[0]/S[1]:.2f}')
    ax.grid(True, alpha=0.3)

    # Principal directions, arrows scaled by the singular values
    ax = axes[2]
    draw_polygon(ax, stats3.cumulative_error_vertices, COLORS['error_region'], alpha=0.3)
    for i, (s, v) in enumerate(zip(S, Vt)):
        ax.arrow(center[0], center[1], v[0]*s*0.8, v[1]*s*0.8,
                head_width=scale*0.03, color=['blue', 'green'][i], linewidth=2,
                label=f'PC{i+1}: [{v[0]:.2f}, {v[1]:.2f}]')
    set_fixed_scale(ax, scale)
    ax.set_title('Principal directions\n(Error is anisotropic)')
    ax.legend(loc='upper left', fontsize=8)

    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs('plots', exist_ok=True)
    plt.savefig('plots/exp3_full_matrices.png', dpi=150, bbox_inches='tight')
    plt.show()


def plot_experiment_4(results, circle_points, scale):
    """Plot Experiment 4 results in a 2x2 figure.

    Panels: the input circle coloured by error magnitude, error
    magnitude against angle, eight selected error regions, and the
    circle with scaled error regions overlaid at every second point.

    Args:
        results: Per-point dicts as returned by run_experiment_4().
        circle_points: Array of the sampled input points (rows).
        scale: Half-size of the axis window for the error-region panel.

    The figure is saved to ``plots/exp4_error_manifold.png`` (the
    directory is created if needed) and then shown.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 14))

    magnitudes = [r['error_magnitude'] for r in results]

    # Error magnitude around the circle
    ax = axes[0, 0]
    scatter = ax.scatter(circle_points[:, 0], circle_points[:, 1],
                        c=magnitudes, cmap='hot', s=100, edgecolors='black')
    plt.colorbar(scatter, ax=ax, label='Error magnitude')
    ax.plot(circle_points[:, 0], circle_points[:, 1], 'b-', alpha=0.3)
    ax.set_title('Circle manifold colored by error')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

    # Error magnitude vs angle
    ax = axes[0, 1]
    # arctan2 returns angles in (-pi, pi]; wrap them to [0, 2*pi) and
    # sort so the connecting line does not jump across the plot where
    # the angle wraps around.
    angles = np.mod(np.arctan2(circle_points[:, 1], circle_points[:, 0]), 2 * np.pi)
    order = np.argsort(angles)
    ax.plot(np.degrees(angles)[order], np.asarray(magnitudes)[order],
            'o-', linewidth=2, markersize=6)
    ax.set_xlabel('Angle (degrees)')
    ax.set_ylabel('Error magnitude')
    ax.set_title('Error varies with direction')
    ax.grid(True, alpha=0.3)

    # Selected error regions
    ax = axes[1, 0]
    n_selected = 8
    indices = np.linspace(0, len(results)-1, n_selected, dtype=int)
    colors_sel = plt.cm.hsv(np.linspace(0, 1, n_selected))
    for idx, color in zip(indices, colors_sel):
        draw_polygon(ax, results[idx]['error_vertices'], color, alpha=0.3, linewidth=1)
    set_fixed_scale(ax, scale)
    ax.set_title('Error regions for 8 points around circle')

    # Manifold with error regions overlaid
    ax = axes[1, 1]
    # Factor applied to the error regions before they are drawn at
    # their input points.
    error_scale = 0.5
    for r in results[::2]:
        vertices = r['error_vertices'] * error_scale + r['input']
        draw_polygon(ax, vertices, COLORS['error_region'], alpha=0.2, linewidth=0.5)
    ax.plot(circle_points[:, 0], circle_points[:, 1], 'b-', linewidth=2, label='Input manifold')
    ax.scatter(circle_points[:, 0], circle_points[:, 1], c='blue', s=20, zorder=5)
    ax.set_title('Input manifold with error \"tubes\"')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)
    ax.legend()

    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs('plots', exist_ok=True)
    plt.savefig('plots/exp4_error_manifold.png', dpi=150, bbox_inches='tight')
    plt.show()


def plot_summary(all_stats):
    """Summary comparison of experiments 1-3 in a 1x3 figure.

    Uses the first three entries of ``all_stats.experiments`` (in
    insertion order) and shows their error volumes, their error regions
    overlaid, and their per-layer spectral norms.

    Args:
        all_stats: AllExperimentStats holding at least the three 2D
            experiments; names are expected to contain a ``':'`` since
            the legend label is the text after it.

    The figure is saved to ``plots/all_experiments_summary.png`` (the
    directory is created if needed) and then shown.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    exp_names = list(all_stats.experiments.keys())[:3]

    # Volume comparison
    ax = axes[0]
    volumes = [all_stats.experiments[name].cumulative_error_volume for name in exp_names]
    ax.bar(range(len(volumes)), volumes, color=[COLORS['layer1'], COLORS['layer2'], COLORS['layer3']])
    ax.set_xticks(range(len(volumes)))
    ax.set_xticklabels(['Uniform\nDiagonal', 'Non-Uniform\nDiagonal', 'Full\nMatrices'], fontsize=9)
    ax.set_ylabel('Error region volume')
    ax.set_title('Total error volume comparison')
    ax.grid(True, alpha=0.3)

    # Overlay all error regions
    ax = axes[1]
    colors = [COLORS['layer1'], COLORS['layer2'], COLORS['layer3']]
    alphas = [0.4, 0.3, 0.2]
    max_extent = 0
    for name, color, alpha in zip(exp_names, colors, alphas):
        stats = all_stats.experiments[name]
        draw_polygon(ax, stats.cumulative_error_vertices, color, alpha=alpha,
                    label=name.split(':')[1].strip())
        # Track the largest coordinate so the window fits every region.
        max_extent = max(max_extent, np.abs(stats.cumulative_error_vertices).max())
    set_fixed_scale(ax, max_extent * 1.2)
    ax.set_title('All error regions overlaid')
    ax.legend(loc='upper left', fontsize=8)

    # Spectral norms, grouped by layer with one bar per experiment
    ax = axes[2]
    x_pos = np.arange(N_LAYERS)
    width = 0.25
    for i, (name, color) in enumerate(zip(exp_names, colors)):
        stats = all_stats.experiments[name]
        norms = [ls.spectral_norm for ls in stats.layer_stats]
        ax.bar(x_pos + i*width, norms, width, label=name.split(':')[1].strip(),
               color=color, alpha=0.7)
    ax.set_xticks(x_pos + width)
    ax.set_xticklabels([f'L{i+1}' for i in range(N_LAYERS)])
    ax.set_ylabel('Spectral norm')
    ax.set_title('Weight spectral norms by layer')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs('plots', exist_ok=True)
    plt.savefig('plots/all_experiments_summary.png', dpi=150, bbox_inches='tight')
    plt.show()


# ============================================================
# Run 2D experiments
# ============================================================

# Single 2D input point shared by experiments 1-3.
x_input = np.array([10.0, 20.0])

print("=" * 70)
print(f"2D EXPERIMENTS — Input: {x_input}, Bits: {BITS}, Delta: {DELTA}")
print("=" * 70)

# Each call also registers its result in ALL_STATS (via run_experiment).
stats1, weights1 = run_experiment_1(x_input)
stats2, weights2 = run_experiment_2(x_input)
stats3, weights3 = run_experiment_3(x_input)

# Global scale for consistent 2D plots: the largest vertex coordinate
# across the three cumulative error regions, plus 30% margin.
GLOBAL_ERROR_SCALE = max(
    np.abs(s.cumulative_error_vertices).max()
    for s in [stats1, stats2, stats3]
) * 1.3

# Same matrices that run_experiment_3() builds internally; they are
# repeated here because that function only returns the quantized weights,
# while run_experiment_4() quantizes its argument itself.
exp3_weights = [
    np.array([[0.9, 0.2], [0.1, 1.0]]),
    np.array([[0.95, -0.15], [0.2, 0.85]]),
    np.array([[1.0, 0.1], [-0.1, 0.9]]),
    np.array([[0.85, 0.15], [0.1, 1.05]]),
]
results4, circle_points, weights4 = run_experiment_4(exp3_weights)

# Plot every experiment, then the cross-experiment summary figure and
# the text summary of everything collected in ALL_STATS.
plot_experiment_1_2(stats1, stats2, GLOBAL_ERROR_SCALE)
plot_experiment_3(stats3, stats1, GLOBAL_ERROR_SCALE)
plot_experiment_4(results4, circle_points, GLOBAL_ERROR_SCALE)
plot_summary(ALL_STATS)
ALL_STATS.print_summary()

print("\nKEY TAKEAWAYS:")
print("1. Uniform diagonal: relative error constant across channels")
print("2. Non-uniform diagonal: channels accumulate error at different rates")
print("3. Full matrices: error region tilted/sheared, PCA reveals anisotropy")
print("4. Error manifold: magnitude and shape vary with input position and direction")

In [None]:
# ============================================================
# 3D EXTENSIONS
# ============================================================
#
# Same 3 experiments but in 3D, revealing:
#   - 3D box/parallelepiped geometry
#   - Channel sensitivity analysis
#   - SVD shape analysis with 3 singular values
#   - Bounding box efficiency
#   - 2D projections

# Fix NumPy's global RNG seed.
# NOTE(review): no random draws are visible in this part of the file —
# confirm whether later cells rely on it.
np.random.seed(42)

# 3D input point shared by the 3D experiments below.
x_input_3d = np.array([10.0, 20.0, 30.0])


# ============================================================
# 3D error computation
# ============================================================

def compute_error_boxes_diagonal_3d(x, weight_channels, delta=DELTA):
    """
    Compute per-layer error boxes for diagonal weight networks (3D).

    The weights are quantized, then the input is propagated channel by
    channel. At each layer the error box has half-widths
    ``delta/2 * |value|``; it is mapped back to input space by dividing
    by the magnitude of the cumulative quantized weight up to and
    including that layer.

    Args:
        x: Input vector (3D in this notebook; any length matching the
            number of weight columns works)
        weight_channels: (n_layers, n_channels) array of per-channel
            weights, pre-quantization
        delta: Quantization step

    Returns:
        Tuple ``(boxes, quant_w)``: ``boxes`` is a list with one dict per
        layer (keys ``'layer'`` (1-based), ``'hw_input'``, ``'hw_output'``,
        ``'value'``, ``'weight'``) and ``quant_w`` is the quantized
        weight array.
    """
    quant_w = quantize(weight_channels, delta)
    boxes = []
    val = x.copy()
    # One running product per channel, sized from the input.
    cumulative_weight = np.ones_like(val, dtype=float)

    for layer_no, w in enumerate(quant_w, start=1):
        box_at_layer = (delta / 2) * np.abs(val)
        cumulative_weight_after = cumulative_weight * w
        # A channel whose cumulative weight is 0 yields an infinite
        # half-width here (NumPy division by zero).
        box_in_input_space = box_at_layer / np.abs(cumulative_weight_after)

        boxes.append({
            'layer': layer_no,
            'hw_input': box_in_input_space.copy(),
            'hw_output': box_at_layer.copy(),
            'value': val.copy(),
            'weight': w.copy(),
        })
        val = w * val
        cumulative_weight = cumulative_weight_after

    return boxes, quant_w


def trace_error_geometry_3d(x, quant_weights, delta=DELTA):
    """
    Follow the error region through a stack of 3D full-matrix layers.

    Each layer contributes an axis-aligned cube in its own output space
    whose half-width is ``delta/2`` times the L1 norm of the layer input.
    The cube is pulled back to input space through the inverse of the
    cumulative transform and merged into the running region with a
    Minkowski sum, so the accumulated region is a polytope.

    Returns a list with one dict per layer holding the layer number
    (1-based), the layer input, this layer's pulled-back vertices, the
    cumulative vertices so far and the local half-widths.
    """
    unit_cube = get_hypercube_vertices(1.0, dims=3)
    half_step = delta / 2

    history = []
    activation = x.copy()
    transform_so_far = np.eye(3)
    accumulated = None

    for layer_no, W in enumerate(quant_weights, start=1):
        hw = half_step * np.sum(np.abs(activation)) * np.ones(3)
        local_vertices = unit_cube * hw

        transform_after = W @ transform_so_far
        try:
            inverse = np.linalg.inv(transform_after)
            contribution = local_vertices @ inverse.T
        except np.linalg.LinAlgError:
            # Singular transform: leave the cube in layer coordinates.
            contribution = local_vertices

        if accumulated is None:
            accumulated = contribution
        else:
            accumulated = minkowski_sum_3d(accumulated, contribution)

        history.append({
            'layer': layer_no,
            'value': activation.copy(),
            'error_vertices_input': contribution.copy(),
            'cumulative_vertices': accumulated.copy(),
            'hw_local': hw.copy(),
        })

        activation = W @ activation
        transform_so_far = transform_after

    return history


# ============================================================
# Exp 1 (3D): Uniform diagonal — cumulative error boxes
# ============================================================

print("=" * 70)
print(f"3D EXPERIMENTS — Input: {x_input_3d}, Bits: {BITS}")
print("=" * 70)

# One row per layer, one column per channel; within a layer every
# channel gets the same weight.
uniform_weights_3d = np.array([[0.8, 0.8, 0.8],
                                [1.2, 1.2, 1.2],
                                [0.9, 0.9, 0.9],
                                [1.1, 1.1, 1.1]])

boxes_uniform, qw_uniform = compute_error_boxes_diagonal_3d(x_input_3d, uniform_weights_3d)

# Axis-aligned boxes add half-width by half-width under a Minkowski sum.
total_hw_uniform = sum(b['hw_input'] for b in boxes_uniform)
print(f"\nExp 1 (3D Uniform): total half-widths = {total_hw_uniform}")
print(f"Relative error per channel: {total_hw_uniform / x_input_3d * 100}%")
print("(Same % for all channels — uniform weights)")

# Plot: nested cumulative error boxes + error growth
fig = plt.figure(figsize=(14, 5))

# Cumulative Minkowski sum (nested boxes): one box per layer, each
# grown by that layer's input-space half-widths.
ax1 = fig.add_subplot(131, projection='3d')
colors_3d = plt.cm.viridis(np.linspace(0.2, 0.8, len(boxes_uniform)))
cumulative_hw = np.zeros(3)
for b, color in zip(boxes_uniform, colors_3d):
    cumulative_hw = cumulative_hw + b['hw_input']
    draw_box_3d(ax1, np.zeros(3), cumulative_hw, color, alpha=0.2)
ax1.set_xlabel('Ch0 (x=10)')
ax1.set_ylabel('Ch1 (x=20)')
ax1.set_zlabel('Ch2 (x=30)')
ax1.set_title('Cumulative error box\n(Minkowski sum, nested)')

# Final box with axis lines marking the extent along each channel
ax2 = fig.add_subplot(132, projection='3d')
draw_box_3d(ax2, np.zeros(3), total_hw_uniform, 'red', alpha=0.4)
for i, (hw, c, label) in enumerate(zip(total_hw_uniform, ['b', 'g', 'r'],
                                        ['Ch0 (x=10)', 'Ch1 (x=20)', 'Ch2 (x=30)'])):
    pts = np.zeros((2, 3))
    pts[0, i] = -hw
    pts[1, i] = hw
    ax2.plot3D(pts[:, 0], pts[:, 1], pts[:, 2], f'{c}-', linewidth=3,
              label=f'{label}: ±{hw:.4f}')
ax2.set_title('Final error box\nLarger input → larger error')
ax2.legend(loc='upper left', fontsize=7)

# Error growth per channel, as a percentage of that channel's input
ax3 = fig.add_subplot(133)
for ch in range(3):
    cumulative = np.cumsum([b['hw_input'][ch] for b in boxes_uniform])
    percentage = 100 * cumulative / x_input_3d[ch]
    ax3.plot([b['layer'] for b in boxes_uniform], percentage, 'o-',
             linewidth=2, markersize=8, label=f'Ch{ch}')
ax3.set_xlabel('Layer')
ax3.set_ylabel('Error as % of input')
ax3.set_title('Relative error (same for all channels)')
ax3.legend()
ax3.grid(True, alpha=0.3)

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs('plots', exist_ok=True)
plt.savefig('plots/exp1_3d_uniform.png', dpi=150, bbox_inches='tight')
plt.show()


# ============================================================
# Exp 2 (3D): Non-uniform diagonal — channel sensitivity
# ============================================================

# One row per layer, one column per channel.
nonuniform_weights_3d = np.array([
    [0.8, 1.2, 0.5],   # Ch1 amplified, Ch2 shrunk
    [1.1, 0.7, 1.3],   # Ch0,2 amplified, Ch1 shrunk
    [0.9, 1.1, 0.9],   # Mild
    [1.2, 0.8, 1.1],   # Ch0 amplified, Ch1 shrunk
])

boxes_nonuniform, qw_nonuniform = compute_error_boxes_diagonal_3d(x_input_3d, nonuniform_weights_3d)
total_hw_nonuniform = sum(b['hw_input'] for b in boxes_nonuniform)

print(f"\nExp 2 (3D Non-Uniform): total half-widths = {total_hw_nonuniform}")

fig = plt.figure(figsize=(18, 5))

# Overlay both boxes in a common cubic window
ax1 = fig.add_subplot(131, projection='3d')
max_hw = max(total_hw_nonuniform.max(), total_hw_uniform.max())
draw_box_3d(ax1, np.zeros(3), total_hw_nonuniform, 'red', alpha=0.3)
draw_box_3d(ax1, np.zeros(3), total_hw_uniform, 'blue', alpha=0.3)
ax1.set_xlabel('Ch0')
ax1.set_ylabel('Ch1')
ax1.set_zlabel('Ch2')
ax1.set_title('Overlay\nRed=Non-uniform, Blue=Uniform')
ax1.set_xlim(-max_hw*1.2, max_hw*1.2)
ax1.set_ylim(-max_hw*1.2, max_hw*1.2)
ax1.set_zlim(-max_hw*1.2, max_hw*1.2)

# Relative error comparison (percent of each channel's input)
ax2 = fig.add_subplot(132)
uniform_rel = total_hw_uniform / x_input_3d * 100
nonuniform_rel = total_hw_nonuniform / x_input_3d * 100
x_pos = np.arange(3)
ax2.bar(x_pos - 0.2, uniform_rel, 0.4, label='Uniform', color='blue', alpha=0.7)
ax2.bar(x_pos + 0.2, nonuniform_rel, 0.4, label='Non-uniform', color='red', alpha=0.7)
ax2.set_xticks(x_pos)
ax2.set_xticklabels([f'Ch{i} (x={int(x_input_3d[i])})' for i in range(3)])
ax2.set_ylabel('Relative error (%)')
ax2.set_title('Relative error comparison')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Channel sensitivity: each channel's relative error divided by the mean
ax3 = fig.add_subplot(133)
mean_rel = nonuniform_rel.mean()
sensitivity = nonuniform_rel / mean_rel
ax3.bar(x_pos, sensitivity, color=['green' if s < 1 else 'red' for s in sensitivity], alpha=0.7)
ax3.axhline(1.0, color='gray', linestyle='--', linewidth=2, label='Average')
ax3.set_xticks(x_pos)
ax3.set_xticklabels([f'Ch{i}' for i in range(3)])
ax3.set_ylabel('Sensitivity (relative to mean)')
ax3.set_title('Channel error sensitivity\nGreen=below avg, Red=above avg')
for i, s in enumerate(sensitivity):
    ax3.annotate(f'{s:.2f}x', (i, s + 0.05), ha='center', fontsize=12)
ax3.legend()
ax3.grid(True, alpha=0.3)

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs('plots', exist_ok=True)
plt.savefig('plots/exp2_3d_sensitivity.png', dpi=150, bbox_inches='tight')
plt.show()


# ============================================================
# Exp 3 (3D): Full matrices — parallelepipeds, SVD, projections
# ============================================================

# Four dense 3x3 layer matrices (non-zero off-diagonal entries), one per layer.
full_weights_3d = [
    np.array([[0.9, 0.1, 0.0], [0.1, 1.1, 0.1], [0.0, 0.1, 0.8]]),
    np.array([[0.8, -0.3, 0.1], [0.3, 0.8, -0.2], [-0.1, 0.2, 0.9]]),
    np.array([[1.1, 0.2, 0.0], [0.0, 0.9, 0.2], [0.1, 0.0, 1.0]]),
    np.array([[0.9, 0.2, -0.1], [-0.2, 1.0, 0.1], [0.1, -0.1, 0.85]]),
]

# quantize() and trace_error_geometry_3d() are defined earlier in the notebook.
# The history is indexed per layer and each entry is used below as a dict with
# at least 'layer', 'error_vertices_input' and 'cumulative_vertices' keys.
quant_full_3d = [quantize(W) for W in full_weights_3d]
history_3d = trace_error_geometry_3d(x_input_3d, quant_full_3d)

# Also compute diagonal-only case for comparison
# (np.diag(np.diag(W)) keeps the diagonal of each quantized matrix and zeroes
# everything else).
diagonal_3d = [np.diag(np.diag(W)) for W in quant_full_3d]
history_diag_3d = trace_error_geometry_3d(x_input_3d, diagonal_3d)

# Vertex sets recorded for the last layer in each history.
final_verts = history_3d[-1]['cumulative_vertices']
final_verts_diag = history_diag_3d[-1]['cumulative_vertices']

print(f"\nExp 3 (3D Full Matrices):")
print(f"  Final error vertices: {len(final_verts)}")
# Per-axis min and max over the vertices, i.e. the axis-aligned bounding box.
print(f"  Bounding box: {final_verts.min(axis=0)} to {final_verts.max(axis=0)}")

# Figure 1: Error region evolution
# One 3D panel per history entry; each panel shows that layer's own error
# region together with the cumulative region after that layer.
fig = plt.figure(figsize=(18, 5))
for i, h in enumerate(history_3d):
    # The 1x4 grid is hard-coded, so this assumes at most four history entries.
    ax = fig.add_subplot(1, 4, i+1, projection='3d')
    # NOTE(review): the per-layer colour comes from colors_3d (defined outside
    # this section); the suptitle below calls it "green" - confirm they agree.
    draw_vertices_and_hull_3d(ax, h['error_vertices_input'], colors_3d[i], alpha=0.4)
    draw_vertices_and_hull_3d(ax, h['cumulative_vertices'], 'red', alpha=0.2)
    # Symmetric axis range: 1.2x the largest absolute cumulative coordinate.
    me = np.abs(h['cumulative_vertices']).max() * 1.2
    ax.set_xlim(-me, me); ax.set_ylim(-me, me); ax.set_zlim(-me, me)
    ax.set_xlabel('Ch0'); ax.set_ylabel('Ch1'); ax.set_zlabel('Ch2')
    ax.set_title(f"Layer {h['layer']}")
plt.suptitle('3D Error Region Evolution (green=layer contribution, red=cumulative)', fontsize=12)
plt.tight_layout()
# NOTE(review): assumes the 'plots/' directory already exists.
plt.savefig('plots/exp3_3d_evolution.png', dpi=150, bbox_inches='tight')
plt.show()

# Figure 2: Final region vs bounding box
# Three 3D panels: the final error region, its axis-aligned bounding box,
# and the two overlaid.
fig = plt.figure(figsize=(16, 6))
# Half of the per-axis extent of the final vertex set.
bbox_hw = (final_verts.max(axis=0) - final_verts.min(axis=0)) / 2
# Common symmetric axis range for all three panels.
max_extent = np.abs(final_verts).max() * 1.2

ax1 = fig.add_subplot(131, projection='3d')
draw_vertices_and_hull_3d(ax1, final_verts, 'red', alpha=0.4)
ax1.set_xlim(-max_extent, max_extent); ax1.set_ylim(-max_extent, max_extent)
ax1.set_zlim(-max_extent, max_extent)
ax1.set_title('Actual error region\n(Non-axis-aligned polytope)')

ax2 = fig.add_subplot(132, projection='3d')
# Scale the vertices returned by get_hypercube_vertices(1.0) per axis by the
# half-widths.
# NOTE(review): presumably those are unit-cube corners centred at the origin,
# so the box is drawn around 0 rather than around the bounding-box midpoint;
# this matches only if the error region is symmetric about 0 - confirm.
box_verts = get_hypercube_vertices(1.0) * bbox_hw
draw_vertices_and_hull_3d(ax2, box_verts, 'blue', alpha=0.4)
ax2.set_xlim(-max_extent, max_extent); ax2.set_ylim(-max_extent, max_extent)
ax2.set_zlim(-max_extent, max_extent)
ax2.set_title('Bounding box\n(Axis-aligned approximation)')

ax3 = fig.add_subplot(133, projection='3d')
draw_vertices_and_hull_3d(ax3, final_verts, 'red', alpha=0.3)
draw_wireframe_box(ax3, bbox_hw, 'blue', alpha=0.8)
ax3.set_xlim(-max_extent, max_extent); ax3.set_ylim(-max_extent, max_extent)
ax3.set_zlim(-max_extent, max_extent)
ax3.set_title('Overlay\nRed=Actual, Blue=Bounding box')
# Same channel labels on every panel.
for ax in [ax1, ax2, ax3]:
    ax.set_xlabel('Ch0'); ax.set_ylabel('Ch1'); ax.set_zlabel('Ch2')

plt.tight_layout()
# NOTE(review): assumes the 'plots/' directory already exists.
plt.savefig('plots/exp3_3d_vs_bbox.png', dpi=150, bbox_inches='tight')
plt.show()

# Figure 3: SVD, volume growth, bounding box efficiency, 2D projections
# 2x2 grid comparing the full-matrix error region against the diagonal-only one.
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# SVD of final error region
# Vertices are mean-centred first, so the singular values describe the spread
# of the vertex cloud along its principal directions (rows of Vt).
centered = final_verts - final_verts.mean(axis=0)
U, S, Vt = np.linalg.svd(centered, full_matrices=False)
centered_diag = final_verts_diag - final_verts_diag.mean(axis=0)
U_d, S_d, Vt_d = np.linalg.svd(centered_diag, full_matrices=False)

ax = axes[0, 0]
x_pos = np.arange(3)
ax.bar(x_pos - 0.175, S, 0.35, label='Full matrix', color='red', alpha=0.7)
ax.bar(x_pos + 0.175, S_d, 0.35, label='Diagonal only', color='blue', alpha=0.7)
ax.set_xticks(x_pos); ax.set_xticklabels(['PC1', 'PC2', 'PC3'])
ax.set_ylabel('Singular value')
# "Condition" here is the ratio of largest to smallest singular value.
ax.set_title(f'Error shape: Full vs Diagonal\nCondition: {S[0]/S[2]:.2f} vs {S_d[0]/S_d[2]:.2f}')
ax.legend(); ax.grid(True, alpha=0.3)

# Volume growth
ax = axes[0, 1]
volumes_full = []
volumes_diag = []
for h_f, h_d in zip(history_3d, history_diag_3d):
    # Best effort: if the hull cannot be built (e.g. a degenerate vertex set)
    # record a volume of 0. `except Exception` instead of a bare `except`
    # so that KeyboardInterrupt / SystemExit are no longer swallowed.
    try:
        volumes_full.append(ConvexHull(h_f['cumulative_vertices']).volume)
    except Exception:
        volumes_full.append(0)
    try:
        volumes_diag.append(ConvexHull(h_d['cumulative_vertices']).volume)
    except Exception:
        volumes_diag.append(0)

layers = [h['layer'] for h in history_3d]
ax.plot(layers, volumes_full, 'o-', linewidth=2, label='Full matrix', color='red')
ax.plot(layers, volumes_diag, 's-', linewidth=2, label='Diagonal only', color='blue')
ax.set_xlabel('Layer'); ax.set_ylabel('Error region volume')
ax.set_title('Error volume growth'); ax.legend(); ax.grid(True, alpha=0.3)

# Bounding box efficiency
# Ratio of hull volume to axis-aligned bounding-box volume per layer.
ax = axes[1, 0]
efficiencies = []
for h in history_3d:
    v = h['cumulative_vertices']
    bbox_vol = np.prod(v.max(axis=0) - v.min(axis=0))
    # If the hull cannot be built, fall back to the box volume (ratio 1).
    try:
        actual_vol = ConvexHull(v).volume
    except Exception:
        actual_vol = bbox_vol
    efficiencies.append(actual_vol / bbox_vol if bbox_vol > 0 else 1)

ax.bar(layers, efficiencies, color='purple', alpha=0.7)
ax.axhline(1.0, color='gray', linestyle='--', label='Perfect efficiency (cube)')
ax.set_xlabel('Layer'); ax.set_ylabel('Actual / Bounding box volume')
ax.set_title(f'Bounding box efficiency\nFinal: {efficiencies[-1]:.3f} '
             f'(overestimates by {100*(1/efficiencies[-1]-1):.0f}%)')
ax.set_ylim(0, 1.2); ax.legend(); ax.grid(True, alpha=0.3)

# 2D projections
# Each pair of channels is drawn as the 2D convex hull of the projected
# vertices: filled/solid for the full-matrix region, dashed for diagonal-only.
ax = axes[1, 1]
projections = [(0, 1, 'Ch0-Ch1'), (0, 2, 'Ch0-Ch2'), (1, 2, 'Ch1-Ch2')]
proj_colors = [COLORS['layer1'], COLORS['layer2'], COLORS['layer3']]
for (i, j, name), color in zip(projections, proj_colors):
    proj_full = final_verts[:, [i, j]]
    proj_diag = final_verts_diag[:, [i, j]]
    # Only the hull construction is guarded; a projection whose hull cannot be
    # built is skipped, while genuine plotting errors now surface.
    try:
        hull = ConvexHull(proj_full)
    except Exception:
        pass
    else:
        hv = proj_full[hull.vertices]
        hv = np.vstack([hv, hv[0]])  # close the outline
        ax.fill(hv[:, 0], hv[:, 1], color=color, alpha=0.2, label=f'Full {name}')
        ax.plot(hv[:, 0], hv[:, 1], color=color, linewidth=2)
    try:
        hull_d = ConvexHull(proj_diag)
    except Exception:
        pass
    else:
        hv_d = proj_diag[hull_d.vertices]
        hv_d = np.vstack([hv_d, hv_d[0]])  # close the outline
        ax.plot(hv_d[:, 0], hv_d[:, 1], color=color, linewidth=1, linestyle='--')
ax.set_xlabel('Dim A'); ax.set_ylabel('Dim B')
ax.set_title('2D projections of 3D error\n(Solid=full, Dashed=diagonal)')
ax.legend(fontsize=7); ax.grid(True, alpha=0.3)
ax.set_aspect('equal')

plt.tight_layout()
# NOTE(review): assumes the 'plots/' directory already exists.
plt.savefig('plots/exp3_3d_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

# Text summary of the Figure 3 quantities: principal directions of the final
# error region (rows of Vt with their singular values) and the final volumes.
print(f"\n3D SVD principal directions:")
for i, v in enumerate(Vt):
    print(f"  PC{i+1} (σ={S[i]:.6f}): [{v[0]:.4f}, {v[1]:.4f}, {v[2]:.4f}]")
print(f"\nFull matrix volume: {volumes_full[-1]:.6f}")
print(f"Diagonal-only volume: {volumes_diag[-1]:.6f}")
# The volume lists hold 0 whenever the hull could not be built, so guard the
# division instead of letting the cell die on a zero denominator.
if volumes_diag[-1] > 0:
    print(f"Ratio: {volumes_full[-1]/volumes_diag[-1]:.3f}x")
else:
    print("Ratio: n/a (diagonal-only volume is zero)")

In [None]:
# ============================================================
# MANIFOLD ANALYSIS
# ============================================================
#
# How does quantization error vary across different input manifolds?
# Tests circle, ellipse, line, spiral, figure-eight, two-blobs,
# and grid inputs.
#
# Key visualization: input manifold → error manifold.
# Since the network is linear (no activations), the error is a
# linear function of input: error(x) = (Q_product - W_product) @ x.
# So circles map to ellipses, lines to lines, etc.


def make_manifold(name, n_points=32, **kwargs):
    """Build a named 2D point set plus a small metadata dict describing it.

    Supported names and the keyword overrides each one reads:
      - "circle":       radius
      - "ellipse":      a, b
      - "line":         start, end
      - "spiral":       turns, r_min, r_max
      - "figure_eight": scale
      - "two_blobs":    center1, center2, std (drawn from the global
                        ``np.random`` state, so not reproducible unless seeded)
      - "grid":         extent (uses ``int(sqrt(n_points))`` points per side,
                        so the point count may be smaller than ``n_points``)

    Returns ``(points, metadata)`` where ``points`` has two columns and
    ``metadata['type']`` is 'closed', 'open', 'clusters' or 'area'.
    Raises ValueError for any other name.
    """
    opt = kwargs.get

    if name == "circle":
        rad = opt('radius', 20)
        ang = np.linspace(0, 2*np.pi, n_points, endpoint=False)
        xs = rad * np.cos(ang)
        ys = rad * np.sin(ang)
        return np.column_stack([xs, ys]), {'radius': rad, 'type': 'closed'}

    if name == "ellipse":
        semi_x = opt('a', 25)
        semi_y = opt('b', 10)
        ang = np.linspace(0, 2*np.pi, n_points, endpoint=False)
        xs = semi_x * np.cos(ang)
        ys = semi_y * np.sin(ang)
        return np.column_stack([xs, ys]), {'a': semi_x, 'b': semi_y, 'type': 'closed'}

    if name == "line":
        p0 = np.array(opt('start', [-25, -10]))
        p1 = np.array(opt('end', [25, 10]))
        # Column of interpolation fractions broadcast against the direction.
        frac = np.linspace(0, 1, n_points)[:, None]
        return p0 + frac * (p1 - p0), {'type': 'open'}

    if name == "spiral":
        n_turns = opt('turns', 2)
        inner = opt('r_min', 5)
        outer = opt('r_max', 25)
        ang = np.linspace(0, n_turns * 2 * np.pi, n_points)
        rad = np.linspace(inner, outer, n_points)
        xs = rad * np.cos(ang)
        ys = rad * np.sin(ang)
        return np.column_stack([xs, ys]), {'type': 'open'}

    if name == "figure_eight":
        size = opt('scale', 15)
        param = np.linspace(0, 2*np.pi, n_points, endpoint=False)
        xs = size * np.sin(param)
        ys = size * np.sin(param) * np.cos(param)
        return np.column_stack([xs, ys]), {'type': 'closed'}

    if name == "two_blobs":
        first_count = n_points // 2
        centre_a = np.array(opt('center1', [-15, 0]))
        centre_b = np.array(opt('center2', [15, 0]))
        spread = opt('std', 5)
        # First blob gets the floor half, second blob the remainder.
        cloud_a = np.random.randn(first_count, 2) * spread + centre_a
        cloud_b = np.random.randn(n_points - first_count, 2) * spread + centre_b
        return np.vstack([cloud_a, cloud_b]), {'type': 'clusters'}

    if name == "grid":
        half = opt('extent', 25)
        per_side = int(np.sqrt(n_points))
        ticks_x = np.linspace(-half, half, per_side)
        ticks_y = np.linspace(-half, half, per_side)
        gx, gy = np.meshgrid(ticks_x, ticks_y)
        cells = np.column_stack([gx.ravel(), gy.ravel()])
        return cells, {'extent': half, 'n_side': per_side, 'type': 'area'}

    raise ValueError(f"Unknown manifold: {name}")


def compute_pointwise_errors(points, true_weights, quant_weights):
    """Compute the actual error vector for each input point.

    Since this network has no activations, the error is a linear function:
    error(x) = (Q_n...Q_1 - W_n...W_1) @ x

    The working dimension is taken from the last axis of ``points``, so the
    same routine serves the 2D manifolds and higher-dimensional inputs alike
    (weights are expected to be square matrices of that size).

    Args:
        points: array of shape (n_points, dim), one input per row.
        true_weights: sequence of float (dim, dim) layer matrices, first
            layer first.
        quant_weights: sequence of the corresponding quantized matrices.

    Returns:
        (errors, W_error) where errors is (n_points, dim) and W_error is the
        (dim, dim) error transform matrix.
    """
    dim = np.asarray(points).shape[-1]
    W_float = np.eye(dim)
    W_quant = np.eye(dim)
    # Left-multiply so that the first layer ends up applied first.
    for Wt, Wq in zip(true_weights, quant_weights):
        W_float = Wt @ W_float
        W_quant = Wq @ W_quant
    W_error = W_quant - W_float
    # Points are rows, hence the transpose.
    errors = points @ W_error.T
    return errors, W_error


def compute_manifold_errors(points, quant_weights, delta=DELTA):
    """Compute error region statistics for all points on a manifold.

    For every input point the layers are walked in order. At each layer a
    square box with half-width ``(delta / 2) * ||activation||_1`` is built,
    passed through ``transform_vertices`` with the inverse of the cumulative
    weight product up to and including that layer, and accumulated with
    ``minkowski_sum_2d``.

    Args:
        points: iterable of 2-element arrays (one input per entry).
        quant_weights: non-empty sequence of 2x2 quantized layer matrices;
            with an empty sequence no region is built and the final
            ``.copy()`` fails.
        delta: quantization step used to size the per-layer box.

    Returns:
        A list with one dict per point, holding 'input', 'error_vertices',
        'error_magnitude' (largest vertex L2 norm), 'error_volume'
        (``compute_polygon_area`` of the region), 'l1_norm' and 'l2_norm'.
    """
    results = []
    for x in points:
        val = x.copy()
        cumulative_W = np.eye(2)
        cumulative_error_vertices = None

        for W in quant_weights:
            l1_norm = np.sum(np.abs(val))
            hw = (delta / 2) * l1_norm
            local_vertices = get_box_vertices_2d([hw, hw])

            cumulative_W_after = W @ cumulative_W
            # Only a singular cumulative matrix is an expected failure; in that
            # case the untransformed box is used. Anything else (including
            # errors inside transform_vertices) now propagates instead of being
            # silently swallowed by a bare `except`.
            try:
                inv_W = np.linalg.inv(cumulative_W_after)
            except np.linalg.LinAlgError:
                error_vertices_input = local_vertices
            else:
                error_vertices_input = transform_vertices(local_vertices, inv_W)

            if cumulative_error_vertices is None:
                cumulative_error_vertices = error_vertices_input
            else:
                cumulative_error_vertices = minkowski_sum_2d(
                    cumulative_error_vertices, error_vertices_input
                )
            # Advance the activation and the running weight product.
            val = W @ val
            cumulative_W = cumulative_W_after

        results.append({
            'input': x.copy(),
            'error_vertices': cumulative_error_vertices.copy(),
            'error_magnitude': np.max(np.linalg.norm(cumulative_error_vertices, axis=1)),
            'error_volume': compute_polygon_area(cumulative_error_vertices),
            'l1_norm': np.sum(np.abs(x)),
            'l2_norm': np.linalg.norm(x)
        })
    return results


def run_all_manifolds(quant_weights, manifold_names=None, n_points=48):
    """Evaluate the error-region statistics on several named manifolds.

    For each name the manifold is generated with ``make_manifold``, analysed
    with ``compute_manifold_errors`` and summarised. Returns a dict keyed by
    manifold name with 'points', 'metadata', 'results' and 'stats' entries;
    'stats' holds min/max/mean/std of the error magnitude, the max/min ratio
    and the correlation between input L1 norm and error magnitude.
    """
    names = manifold_names
    if names is None:
        names = ['circle', 'ellipse', 'line', 'spiral', 'figure_eight', 'two_blobs']

    summary = {}
    for manifold in names:
        print(f"  Processing manifold: {manifold}")
        pts, meta = make_manifold(manifold, n_points=n_points)
        per_point = compute_manifold_errors(pts, quant_weights)

        mags = [entry['error_magnitude'] for entry in per_point]
        l1_values = [entry['l1_norm'] for entry in per_point]

        smallest = np.min(mags)
        largest = np.max(mags)
        stats = {
            'error_mag_min': smallest,
            'error_mag_max': largest,
            'error_mag_mean': np.mean(mags),
            'error_mag_std': np.std(mags),
            # Spread of the error across the manifold (max over min).
            'variation_ratio': largest / smallest,
            'correlation_l1': np.corrcoef(l1_values, mags)[0, 1],
        }

        summary[manifold] = {
            'points': pts,
            'metadata': meta,
            'results': per_point,
            'stats': stats,
        }
    return summary


# ============================================================
# Manifold visualization
# ============================================================

def plot_manifold_comparison(all_results, scale=None):
    """Draw two figures comparing error behaviour across manifolds.

    Figure 1 shows every manifold as a scatter plot coloured by error
    magnitude (three panels per row); figure 2 shows bar charts of the error
    range, the max/min ratio and the L1-norm correlation. ``all_results`` is
    the dict produced by ``run_all_manifolds``. When ``scale`` is None the
    shared axis range is 1.4x the largest absolute coordinate of any manifold.
    Both figures are saved under 'plots/' and shown.
    """
    per_row = 3
    count = len(all_results)
    row_count = -(-count // per_row)  # ceiling division

    if scale is None:
        widest = max(np.abs(entry['points']).max() for entry in all_results.values())
        scale = widest * 1.4

    # Figure 1: Manifolds colored by error
    fig, grid = plt.subplots(row_count, per_row, figsize=(5*per_row, 5*row_count))
    panels = grid.flatten()

    for panel, (label, entry) in zip(panels, all_results.items()):
        pts = entry['points']
        kind = entry['metadata']['type']
        mags = [rec['error_magnitude'] for rec in entry['results']]

        dots = panel.scatter(pts[:, 0], pts[:, 1],
                             c=mags, cmap='hot', s=60, edgecolors='black', linewidth=0.5)
        plt.colorbar(dots, ax=panel, label='Error mag')

        # Curves get a faint connecting line; closed ones are joined end to start.
        if kind == 'closed':
            path = np.vstack([pts, pts[0]])
        elif kind == 'open':
            path = pts
        else:
            path = None
        if path is not None:
            panel.plot(path[:, 0], path[:, 1], 'b-', alpha=0.3, linewidth=1)

        summary = entry['stats']
        panel.set_title(f"{label}\nVar ratio: {summary['variation_ratio']:.2f}x, "
                        f"Corr(L1): {summary['correlation_l1']:.2f}")
        set_fixed_scale(panel, scale)

    # Hide the unused trailing panels of the last row.
    for spare in panels[count:]:
        spare.set_visible(False)

    plt.tight_layout()
    plt.savefig('plots/manifolds_error_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Figure 2: Summary statistics
    fig, trio = plt.subplots(1, 3, figsize=(15, 5))
    labels = list(all_results.keys())
    slots = np.arange(len(labels))

    def column(key):
        """Collect one statistic across all manifolds, in label order."""
        return [all_results[n]['stats'][key] for n in labels]

    def label_ticks(target):
        target.set_xticks(slots)
        target.set_xticklabels(labels, rotation=45, ha='right')

    # Error range: max, mean and min drawn over each other.
    range_ax = trio[0]
    layers_spec = (
        ('error_mag_max', 0.3, 'red', 'Max'),
        ('error_mag_mean', 0.5, 'blue', 'Mean'),
        ('error_mag_min', 0.7, 'green', 'Min'),
    )
    for key, opacity, shade, legend_name in layers_spec:
        range_ax.bar(slots, column(key), alpha=opacity, color=shade, label=legend_name)
    label_ticks(range_ax)
    range_ax.set_ylabel('Error magnitude')
    range_ax.set_title('Error range by manifold')
    range_ax.legend()
    range_ax.grid(True, alpha=0.3)

    # Variation ratio
    ratio_ax = trio[1]
    ratio_ax.bar(slots, column('variation_ratio'), color='purple', alpha=0.7)
    label_ticks(ratio_ax)
    ratio_ax.set_ylabel('Max/Min error ratio')
    ratio_ax.set_title('Error variation within manifold')
    ratio_ax.axhline(1.0, color='gray', linestyle='--')
    ratio_ax.grid(True, alpha=0.3)

    # Correlation with L1 norm, colour-coded by strength.
    corr_ax = trio[2]
    corr_values = column('correlation_l1')
    shades = []
    for value in corr_values:
        if value > 0.8:
            shades.append('green')
        elif value > 0.5:
            shades.append('orange')
        else:
            shades.append('red')
    corr_ax.bar(slots, corr_values, color=shades, alpha=0.7)
    label_ticks(corr_ax)
    corr_ax.set_ylabel('Correlation')
    corr_ax.set_title('Error vs L1 norm correlation\n(Green=predictable, Red=complex)')
    corr_ax.set_ylim(0, 1.1)
    corr_ax.axhline(1.0, color='gray', linestyle='--')
    corr_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('plots/manifolds_statistics.png', dpi=150, bbox_inches='tight')
    plt.show()


def plot_grid_heatmap(quant_weights, delta=DELTA, extent=30, n_side=15):
    """Plot error magnitude over a square grid of inputs.

    Builds an ``n_side`` x ``n_side`` grid on [-extent, extent]^2, evaluates
    ``compute_manifold_errors`` on it and draws (left) a heatmap of the error
    magnitude and (right) error magnitude against input L1 norm with a linear
    fit, coloured by input angle. The figure is saved under 'plots/' and shown.
    """
    ticks = np.linspace(-extent, extent, n_side)
    mesh_x, mesh_y = np.meshgrid(ticks, ticks)
    samples = np.column_stack([mesh_x.ravel(), mesh_y.ravel()])

    records = compute_manifold_errors(samples, quant_weights, delta)
    magnitude_map = np.array([rec['error_magnitude'] for rec in records]).reshape(n_side, n_side)

    fig, (heat_ax, trend_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Heatmap
    image = heat_ax.imshow(magnitude_map, extent=[-extent, extent, -extent, extent],
                           origin='lower', cmap='hot')
    plt.colorbar(image, ax=heat_ax, label='Error magnitude')
    heat_ax.set_xlabel('Input dim 0')
    heat_ax.set_ylabel('Input dim 1')
    heat_ax.set_title('Error magnitude across input space')
    heat_ax.set_aspect('equal')

    # Error vs L1 norm
    input_l1 = np.sum(np.abs(samples), axis=1)
    magnitudes = magnitude_map.ravel()
    angles = np.arctan2(samples[:, 1], samples[:, 0])
    trend_ax.scatter(input_l1, magnitudes, alpha=0.5, s=20, c=angles, cmap='hsv')
    # Degree-1 least-squares fit: z[0] is the slope, z[1] the intercept.
    z = np.polyfit(input_l1, magnitudes, 1)
    fit_line = np.poly1d(z)
    fit_x = np.linspace(input_l1.min(), input_l1.max(), 100)
    trend_ax.plot(fit_x, fit_line(fit_x), 'r-', linewidth=2,
                  label=f'Linear fit: y={z[0]:.4f}x + {z[1]:.4f}')
    trend_ax.set_xlabel('L1 norm of input')
    trend_ax.set_ylabel('Error magnitude')
    trend_ax.set_title('Error scales with L1 norm\n(color=angle, shows directional dependence)')
    trend_ax.legend()
    trend_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('plots/manifold_grid_heatmap.png', dpi=150, bbox_inches='tight')
    plt.show()


def plot_error_manifolds(true_weights, quant_weights, manifold_names=None, n_points=64):
    """Plot input manifold vs error manifold side by side.

    Since the network is linear (no activations), the error is:
        error(x) = (Q_n...Q_1 - W_n...W_1) @ x
    This is a linear map, so circles become ellipses, lines stay lines, etc.

    Color encodes position along the manifold so you can see which input
    point maps to which error point.

    Args:
        true_weights: sequence of float layer matrices.
        quant_weights: sequence of the corresponding quantized matrices.
        manifold_names: names accepted by ``make_manifold``; defaults to the
            five curve-like manifolds. A single name is supported.
        n_points: number of points requested per manifold.

    Saves the figure to 'plots/error_manifolds.png', shows it, and prints the
    error transform matrix together with its singular values.
    """
    if manifold_names is None:
        manifold_names = ['circle', 'ellipse', 'line', 'spiral', 'figure_eight']

    n = len(manifold_names)
    # squeeze=False keeps `axes` two-dimensional even when n == 1; without it
    # a single manifold yields a 1-D array and axes[i, 0] raises IndexError.
    fig, axes = plt.subplots(n, 2, figsize=(12, 4 * n), squeeze=False)

    for i, name in enumerate(manifold_names):
        points, metadata = make_manifold(name, n_points=n_points)
        errors, W_error = compute_pointwise_errors(points, true_weights, quant_weights)

        # Color by position along manifold
        t = np.linspace(0, 1, len(points))
        is_connected = metadata['type'] in ('closed', 'open')
        is_closed = metadata['type'] == 'closed'

        # --- Input manifold ---
        ax = axes[i, 0]
        ax.scatter(points[:, 0], points[:, 1], c=t, cmap='viridis',
                   s=30, edgecolors='black', linewidth=0.5)
        if is_connected:
            # Closed curves are joined back to their first point.
            conn = np.vstack([points, points[0]]) if is_closed else points
            ax.plot(conn[:, 0], conn[:, 1], 'k-', alpha=0.3, linewidth=1)
        ax.set_aspect('equal')
        ax.grid(True, alpha=0.3)
        ax.axhline(0, color='k', linewidth=0.5); ax.axvline(0, color='k', linewidth=0.5)
        ax.set_title(f'{name} — input')
        # Only the bottom row carries an x-axis label.
        if i == n - 1:
            ax.set_xlabel('Dim 0')
        ax.set_ylabel('Dim 1')

        # --- Error manifold ---
        ax = axes[i, 1]
        ax.scatter(errors[:, 0], errors[:, 1], c=t, cmap='viridis',
                   s=30, edgecolors='black', linewidth=0.5)
        if is_connected:
            conn_e = np.vstack([errors, errors[0]]) if is_closed else errors
            ax.plot(conn_e[:, 0], conn_e[:, 1], 'k-', alpha=0.3, linewidth=1)
        ax.set_aspect('equal')
        ax.grid(True, alpha=0.3)
        ax.axhline(0, color='k', linewidth=0.5); ax.axvline(0, color='k', linewidth=0.5)
        error_mag = np.linalg.norm(errors, axis=1)
        ax.set_title(f'{name} — error (max |e|={error_mag.max():.4f})')
        if i == n - 1:
            ax.set_xlabel('Error dim 0')
        ax.set_ylabel('Error dim 1')

    plt.suptitle('Input Manifold  →  Error Manifold\n'
                 'Error is a linear map of input: circles→ellipses, lines→lines',
                 fontsize=13, y=1.02)
    plt.tight_layout()
    plt.savefig('plots/error_manifolds.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Print the error transform properties
    # (recomputed here so it does not depend on the loop having run).
    _, W_error = compute_pointwise_errors(np.eye(2), true_weights, quant_weights)
    U, S, Vt = np.linalg.svd(W_error)
    print(f"\nError transform matrix (Q_product - W_product):")
    print(f"  [[{W_error[0,0]:.6f}, {W_error[0,1]:.6f}],")
    print(f"   [{W_error[1,0]:.6f}, {W_error[1,1]:.6f}]]")
    print(f"  Singular values: {S[0]:.6f}, {S[1]:.6f}")
    print(f"  Condition number: {S[0]/S[1]:.2f}")
    print(f"  → Circle of radius r maps to ellipse with semi-axes "
          f"{S[0]:.6f}r x {S[1]:.6f}r")
    # Rows of Vt are the input directions of largest / smallest stretch.
    print(f"  → Max stretch direction: [{Vt[0,0]:.3f}, {Vt[0,1]:.3f}]")
    print(f"  → Min stretch direction: [{Vt[1,0]:.3f}, {Vt[1,1]:.3f}]")


# ============================================================
# Run manifold analysis
# ============================================================

# Use weights from experiment 3
# Four 2x2 float layer matrices; manifold_qw holds their quantized versions
# (quantize() is defined earlier in the notebook).
manifold_weights = [
    np.array([[0.9, 0.2], [0.1, 1.0]]),
    np.array([[0.95, -0.15], [0.2, 0.85]]),
    np.array([[1.0, 0.1], [-0.1, 0.9]]),
    np.array([[0.85, 0.15], [0.1, 1.05]]),
]
manifold_qw = [quantize(W) for W in manifold_weights]

print("=" * 70)
print("MANIFOLD COMPARISON")
print("=" * 70)

# Error-region statistics for the default manifold set, 48 points each.
all_manifold_results = run_all_manifolds(manifold_qw, n_points=48)

# Fixed-width summary table, one row per manifold.
print(f"\n{'Manifold':<15} {'Min Error':<12} {'Max Error':<12} {'Var Ratio':<12} {'Corr(L1)':<10}")
print("-" * 60)
for name, data in all_manifold_results.items():
    s = data['stats']
    print(f"{name:<15} {s['error_mag_min']:<12.4f} {s['error_mag_max']:<12.4f} "
          f"{s['variation_ratio']:<12.2f} {s['correlation_l1']:<10.2f}")

# Both calls save figures under 'plots/' and display them.
plot_manifold_comparison(all_manifold_results)
plot_grid_heatmap(manifold_qw)

# Error manifold visualization: input shape → error shape
# This one needs the float weights as well, to form the actual error map.
print("\n" + "=" * 70)
print("ERROR MANIFOLDS: Input Shape → Error Shape")
print("=" * 70)
plot_error_manifolds(manifold_weights, manifold_qw)

# Fixed commentary; it is not derived from the numbers computed above.
print("\nOBSERVATIONS:")
print("1. CIRCLE → ELLIPSE: the linear error transform stretches/rotates")
print("2. LINE → LINE: at a different angle and length")
print("3. SPIRAL → SPIRAL: transformed but still spiral-shaped")
print("4. The error transform is a single 2x2 matrix applied to all inputs")
print("5. SVD of the error matrix tells you: how much stretch, in what direction")
print("6. High condition number = error is much worse in some directions than others")

In [None]:
# ============================================================
# FAKE QUANTIZATION — LAYER-BY-LAYER
# ============================================================
#
# Proper fake quantization as used in QAT and integer inference:
#
#   Per layer:
#     1. Weights are fake-quantized (per-tensor, static)
#     2. Matmul: input @ W_fq.T  (simulates int8×int8 → int32 accumulator, which is exact)
#     3. Output is fake-quantized (per-tensor across all activations — simulates
#        requantization from int32 accumulator back to int8)
#
# Key: activation quantization uses ONE scale for the entire activation
# tensor (per-tensor), not per-vector. All points share the same
# rounding grid at each layer.


def fake_quantize(x, bits=8):
    """Round-trip ``x`` through a symmetric ``bits``-bit integer grid.

    A single scale is derived from the largest absolute entry of ``x``:

        scale  = max(|x|) / (2^(bits-1) - 1)
        q      = clamp(round(x / scale), -2^(bits-1), 2^(bits-1) - 1)
        output = q * scale

    If every entry is smaller than 1e-10 in magnitude, an unmodified copy of
    ``x`` is returned instead (no usable scale exists).
    """
    peak = np.abs(x).max()
    if peak < 1e-10:
        return x.copy()

    half_range = 1 << (bits - 1)
    top_level = half_range - 1
    bottom_level = -half_range

    step = peak / top_level
    levels = np.round(x / step)
    levels = np.clip(levels, bottom_level, top_level)
    return levels * step


def run_manifold_fq(points, weight_matrices, bits=8, quantize_activations=True):
    """Push a point set through the network on a float and a fake-quantized path.

    Mirrors integer inference:
      int8_input × int8_weight → int32_accumulator → requantize → int8_output

    The matmul itself is treated as exact (int32 accumulator); rounding is
    introduced only when weights are fake-quantized and, if
    ``quantize_activations`` is true, when each layer output is requantized.
    Requantization is per-tensor: one scale for the whole (n_points, dim)
    activation matrix of a layer.

    Returns:
        float_acts: activations at each layer boundary (float path)
        quant_acts: activations at each layer boundary (quantized path)
        errors: quant - float at each layer boundary
    Each list starts with the (unmodified) input, so it has one entry more
    than there are layers.
    """
    fq_weights = [fake_quantize(mat, bits=bits) for mat in weight_matrices]

    reference = points.copy()
    simulated = points.copy()
    float_acts = [points.copy()]
    quant_acts = [points.copy()]

    for exact_W, rounded_W in zip(weight_matrices, fq_weights):
        # Float path: exact matmul (points are rows, hence the transpose).
        reference = reference @ exact_W.T
        float_acts.append(reference.copy())

        # Quantized path: exact matmul with rounded weights ...
        simulated = simulated @ rounded_W.T
        # ... optionally followed by per-tensor requantization of the output.
        if quantize_activations:
            simulated = fake_quantize(simulated, bits=bits)
        quant_acts.append(simulated.copy())

    errors = [quant_acts[k] - float_acts[k] for k in range(len(float_acts))]
    return float_acts, quant_acts, errors


# ============================================================
# Visualization: weight-only vs full quantization comparison
# ============================================================

def plot_weight_vs_full_comparison(weight_matrices, bits=8,
                                   manifold_name='circle', n_points=128):
    """Compare weight-only against weight+activation quantization errors.

    Runs ``run_manifold_fq`` twice on one manifold (activation quantization
    off, then on) and draws a 2-row grid: the first column repeats the input
    manifold, the remaining columns show the error after each layer, each
    annotated with its maximum error norm. The figure is saved under 'plots/'
    and shown.
    """
    pts, meta = make_manifold(manifold_name, n_points=n_points)
    hue = np.linspace(0, 1, len(pts))

    weight_only_errors = run_manifold_fq(
        pts, weight_matrices, bits, quantize_activations=False
    )[2]
    full_errors = run_manifold_fq(
        pts, weight_matrices, bits, quantize_activations=True
    )[2]

    depth = len(weight_matrices)
    fig, grid = plt.subplots(2, depth + 1,
                             figsize=(3.5 * (depth + 1), 7))

    closed = meta['type'] == 'closed'
    connected = meta['type'] in ('closed', 'open')

    def draw_cloud(panel, xy):
        """Scatter a coloured point set, connect it if it is a curve, add axes."""
        panel.scatter(xy[:, 0], xy[:, 1], c=hue, cmap='viridis',
                      s=15, edgecolors='none')
        if connected:
            path = np.vstack([xy, xy[0]]) if closed else xy
            panel.plot(path[:, 0], path[:, 1], 'k-', alpha=0.15, linewidth=0.5)
        panel.set_aspect('equal')
        panel.grid(True, alpha=0.2)
        panel.axhline(0, color='k', linewidth=0.3)
        panel.axvline(0, color='k', linewidth=0.3)

    rows = (
        ('Weight-only', weight_only_errors),
        ('Weight + activation', full_errors),
    )
    for r, (row_label, layer_errors) in enumerate(rows):
        top_row = r == 0

        # First column: the input manifold, labelled with the row's mode.
        first = grid[r, 0]
        draw_cloud(first, pts)
        first.set_ylabel(row_label, fontsize=10, fontweight='bold')
        if top_row:
            first.set_title('Input', fontweight='bold')

        # Remaining columns: error after each layer (index 0 is the input).
        for layer in range(1, depth + 1):
            panel = grid[r, layer]
            delta_xy = layer_errors[layer]
            draw_cloud(panel, delta_xy)
            if top_row:
                panel.set_title(f'Error after L{layer - 1 + 1}', fontweight='bold')
            norms = np.linalg.norm(delta_xy, axis=1)
            panel.text(0.95, 0.95, f'max|e|={norms.max():.4f}',
                       transform=panel.transAxes, ha='right', va='top',
                       fontsize=7, color='red',
                       bbox=dict(boxstyle='round,pad=0.3',
                                 facecolor='white', alpha=0.8))

    plt.suptitle(
        f'{bits}-bit: Weight-only (top) vs Weight+Activation (bottom)\n'
        f'Activation requantization adds rounding at grid boundaries',
        fontsize=12, y=1.04
    )
    plt.tight_layout()
    plt.savefig(f'plots/fq{bits}_weight_vs_full.png', dpi=150,
                bbox_inches='tight')
    plt.show()


# ============================================================
# Visualization: layer-by-layer error evolution
# ============================================================

def plot_fq_error_evolution(weight_matrices, bits=8,
                            manifold_names=None, n_points=128):
    """Grid: rows = manifolds, cols = input + error after each layer.

    Each manifold is run through ``run_manifold_fq`` (weights and activations
    fake-quantized). The first column shows the input manifold; every further
    column shows the error (quantized minus float activations) after that
    layer. The last column is annotated with the maximum error norm.

    Args:
        weight_matrices: sequence of float layer matrices.
        bits: bit width passed to ``run_manifold_fq``.
        manifold_names: names accepted by ``make_manifold``; defaults to the
            five curve-like manifolds. A single name is supported.
        n_points: number of points requested per manifold.

    The figure is saved to 'plots/fq{bits}_error_evolution.png' and shown.
    """
    if manifold_names is None:
        manifold_names = ['circle', 'ellipse', 'line', 'spiral', 'figure_eight']

    n_manifolds = len(manifold_names)
    n_layers = len(weight_matrices)
    n_cols = n_layers + 1

    # squeeze=False keeps `axes` two-dimensional for a single manifold (or a
    # single column); otherwise axes[row, col] raises IndexError on a 1-D array.
    fig, axes = plt.subplots(n_manifolds, n_cols,
                              figsize=(3.5 * n_cols, 3.5 * n_manifolds),
                              squeeze=False)

    for row, name in enumerate(manifold_names):
        points, metadata = make_manifold(name, n_points=n_points)
        float_acts, quant_acts, errors = run_manifold_fq(
            points, weight_matrices, bits
        )

        # Colour encodes position along the manifold.
        t = np.linspace(0, 1, len(points))
        is_connected = metadata['type'] in ('closed', 'open')
        is_closed = metadata['type'] == 'closed'

        # Input manifold
        ax = axes[row, 0]
        ax.scatter(points[:, 0], points[:, 1], c=t, cmap='viridis',
                   s=12, edgecolors='none')
        if is_connected:
            # Closed curves are joined back to their first point.
            conn = np.vstack([points, points[0]]) if is_closed else points
            ax.plot(conn[:, 0], conn[:, 1], 'k-', alpha=0.15, linewidth=0.5)
        ax.set_aspect('equal'); ax.grid(True, alpha=0.2)
        ax.axhline(0, color='k', linewidth=0.3)
        ax.axvline(0, color='k', linewidth=0.3)
        if row == 0:
            ax.set_title('Input', fontweight='bold')
        ax.set_ylabel(name, fontsize=11, fontweight='bold')

        # Error at each layer
        for col in range(n_layers):
            ax = axes[row, col + 1]
            # errors[0] belongs to the input, so layer k sits at index k + 1.
            err = errors[col + 1]
            ax.scatter(err[:, 0], err[:, 1], c=t, cmap='viridis',
                       s=12, edgecolors='none')
            if is_connected:
                conn_e = np.vstack([err, err[0]]) if is_closed else err
                ax.plot(conn_e[:, 0], conn_e[:, 1], 'k-',
                        alpha=0.15, linewidth=0.5)
            ax.set_aspect('equal'); ax.grid(True, alpha=0.2)
            ax.axhline(0, color='k', linewidth=0.3)
            ax.axvline(0, color='k', linewidth=0.3)
            if row == 0:
                ax.set_title(f'Error after L{col+1}', fontweight='bold')
            # Only the final layer gets the max-error annotation.
            if col == n_layers - 1:
                err_mag = np.linalg.norm(err, axis=1)
                ax.text(0.95, 0.95, f'max|e|={err_mag.max():.4f}',
                        transform=ax.transAxes, ha='right', va='top',
                        fontsize=7, color='red',
                        bbox=dict(boxstyle='round,pad=0.3',
                                  facecolor='white', alpha=0.8))

    plt.suptitle(f'{bits}-bit Fake Quantization — Error Manifold Evolution\n'
                 f'Per-tensor weight + activation quantization',
                 fontsize=13, y=1.03)
    plt.tight_layout()
    plt.savefig(f'plots/fq{bits}_error_evolution.png', dpi=150,
                bbox_inches='tight')
    plt.show()


def plot_fq_error_growth(weight_matrices, bits=8,
                         manifold_names=None, n_points=128):
    """Plot max and mean error magnitude per layer, one curve per manifold.

    For every manifold name, points are generated with ``make_manifold`` and
    pushed through ``run_manifold_fq``; the Euclidean norm (over axis 1) of
    each entry of the returned ``errors`` sequence is reduced to its maximum
    (left panel) and its mean (right panel). The figure is written to
    ``plots/fq{bits}_error_growth.png`` and then displayed.

    Args:
        weight_matrices: Layer weight matrices forwarded to
            ``run_manifold_fq``; their count also determines the x tick
            labels ('input', 'L1', ...).
        bits: Bit width forwarded to ``run_manifold_fq`` and used in the
            title and output file name.
        manifold_names: Names understood by ``make_manifold``. ``None``
            selects the five built-in defaults listed below.
        n_points: Number of points requested from ``make_manifold``.
    """
    names = manifold_names
    if names is None:
        names = ['circle', 'ellipse', 'line', 'spiral', 'figure_eight']

    fig, (ax_max, ax_mean) = plt.subplots(1, 2, figsize=(14, 5))
    curve_style = dict(linewidth=2, markersize=6)

    for name in names:
        points, _ = make_manifold(name, n_points=n_points)
        _, _, errors = run_manifold_fq(points, weight_matrices, bits)

        # One norm evaluation per layer feeds both summary statistics.
        magnitudes = [np.linalg.norm(layer_err, axis=1) for layer_err in errors]
        xs = list(range(len(magnitudes)))

        ax_max.plot(xs, [m.max() for m in magnitudes], 'o-',
                    label=name, **curve_style)
        ax_mean.plot(xs, [m.mean() for m in magnitudes], 'o-',
                     label=name, **curve_style)

    # Tick 0 is labelled 'input'; ticks 1..n are labelled L1..Ln.
    n_layers = len(weight_matrices)
    tick_labels = ['input']
    tick_labels.extend(f'L{k}' for k in range(1, n_layers + 1))

    for panel, title in ((ax_max, 'Max |error|'), (ax_mean, 'Mean |error|')):
        panel.set_xlabel('Layer')
        panel.set_ylabel(title)
        panel.set_title(f'{title} growth through layers')
        panel.legend(fontsize=8)
        panel.grid(True, alpha=0.3)
        panel.set_xticks(range(n_layers + 1))
        panel.set_xticklabels(tick_labels)

    plt.suptitle(f'{bits}-bit Error Accumulation', fontsize=13)
    plt.tight_layout()
    plt.savefig(f'plots/fq{bits}_error_growth.png', dpi=150,
                bbox_inches='tight')
    plt.show()


def plot_bits_comparison(weight_matrices, manifold_name='circle',
                         bits_list=None, n_points=128):
    """Compare error manifolds at different bit widths for one manifold.

    Draws a grid with one row per entry of ``bits_list``. Column 0 shows the
    first two coordinates of the input manifold; column ``k`` (k >= 1) shows
    ``errors[k]`` as returned by ``run_manifold_fq`` for that row's bit
    width, annotated with the largest per-point Euclidean error norm.
    The figure is saved to ``plots/fq_bits_comparison.png`` and then shown.

    Args:
        weight_matrices: Layer weight matrices forwarded to
            ``run_manifold_fq``; their count sets the number of error
            columns.
        manifold_name: Name passed to ``make_manifold``; also shown in the
            figure title.
        bits_list: Bit widths to compare, one grid row each. ``None``
            selects ``[2, 4, 8]``. A single-entry list is supported.
        n_points: Number of points requested from ``make_manifold``.
    """
    if bits_list is None:
        bits_list = [2, 4, 8]

    points, metadata = make_manifold(manifold_name, n_points=n_points)
    t = np.linspace(0, 1, len(points))  # colour value per sample, in order
    is_closed = metadata['type'] == 'closed'
    is_connected = metadata['type'] in ('closed', 'open')

    n_bits = len(bits_list)
    n_layers = len(weight_matrices)

    # squeeze=False keeps `axes` two-dimensional even when the grid has a
    # single row or column, so axes[row, col] indexing below never fails.
    fig, axes = plt.subplots(n_bits, n_layers + 1,
                             figsize=(3.5 * (n_layers + 1), 3.5 * n_bits),
                             squeeze=False)

    def draw_panel(ax, xy):
        """Scatter ``xy`` coloured by sample order, plus outline and axes."""
        ax.scatter(xy[:, 0], xy[:, 1], c=t, cmap='viridis',
                   s=12, edgecolors='none')
        if is_connected:
            # Closed manifolds get the last point joined back to the first.
            path = np.vstack([xy, xy[0]]) if is_closed else xy
            ax.plot(path[:, 0], path[:, 1], 'k-', alpha=0.15, linewidth=0.5)
        ax.set_aspect('equal')
        ax.grid(True, alpha=0.2)
        ax.axhline(0, color='k', linewidth=0.3)
        ax.axvline(0, color='k', linewidth=0.3)

    for row, bits in enumerate(bits_list):
        _, _, errors = run_manifold_fq(points, weight_matrices, bits)

        # Column 0: the input manifold itself.
        ax = axes[row, 0]
        draw_panel(ax, points)
        ax.set_ylabel(f'{bits}-bit', fontsize=12, fontweight='bold')
        if row == 0:
            ax.set_title('Input', fontweight='bold')

        # Columns 1..n_layers: errors[1..n_layers] (errors[0] is not drawn).
        for col in range(n_layers):
            ax = axes[row, col + 1]
            err = errors[col + 1]
            draw_panel(ax, err)
            if row == 0:
                ax.set_title(f'Error after L{col+1}', fontweight='bold')
            err_mag = np.linalg.norm(err, axis=1)
            ax.text(0.95, 0.95, f'|e|={err_mag.max():.4f}',
                    transform=ax.transAxes, ha='right', va='top',
                    fontsize=7, color='red',
                    bbox=dict(boxstyle='round,pad=0.3',
                              facecolor='white', alpha=0.8))

    plt.suptitle(f'Bit Width Comparison: {manifold_name}\n'
                 f'Lower bits → larger errors, more rounding distortion',
                 fontsize=13, y=1.03)
    plt.tight_layout()
    plt.savefig('plots/fq_bits_comparison.png', dpi=150,
                bbox_inches='tight')
    plt.show()


# ============================================================
# Run
# ============================================================

# Four 2x2 layer weight matrices shared by every fake-quantization plot below.
fq_weights = [
    np.array(rows)
    for rows in (
        [[0.9, 0.2], [0.1, 1.0]],
        [[0.95, -0.15], [0.2, 0.85]],
        [[1.0, 0.1], [-0.1, 0.9]],
        [[0.85, 0.15], [0.1, 1.05]],
    )
]

# Banner and pipeline summary, emitted as a single newline-separated print.
print(
    "=" * 70,
    "FAKE QUANTIZATION — LAYER-BY-LAYER ERROR EVOLUTION",
    "=" * 70,
    "",
    "Pipeline per layer:",
    "  int8_input × int8_weight → int32_accumulator (exact) → requantize → int8_output",
    "  Activation quantization: per-tensor (one scale for all points)",
    "",
    sep="\n",
)

# Largest element-wise gap between each weight matrix and the result of
# fake_quantize(..., bits=8) applied to it.
for i, W in enumerate(fq_weights):
    W_q = fake_quantize(W, bits=8)
    print(f"  Layer {i+1} weight max quant error: {np.abs(W - W_q).max():.6f}")

# Weight-only vs full comparison (circle manifold, 8-bit)
print("\n--- Weight-only vs weight+activation (8-bit, circle) ---")
plot_weight_vs_full_comparison(
    fq_weights,
    bits=8,
    manifold_name='circle',
    n_points=128,
)

# Full error evolution, then error growth curves, both at 8-bit
plot_fq_error_evolution(fq_weights, bits=8, n_points=128)
plot_fq_error_growth(fq_weights, bits=8, n_points=128)

# Bit-width comparison: 2, 4 and 8 bits on the circle manifold
print("\nBit-width comparison on circle manifold:")
plot_bits_comparison(
    fq_weights,
    manifold_name='circle',
    bits_list=[2, 4, 8],
    n_points=128,
)