In [None]:
# ============================================================
# Quantization Error Geometry: A Visual Exploration
# ============================================================
#
# This notebook traces how quantization error regions evolve through
# neural network layers, building intuition from simple to complex cases.
#
# All experiments use 2D for visualization clarity.
# Scales are fixed across plots for fair comparison.
#
# Experiments:
# 1. Uniform diagonal weights (baseline)
# 2. Non-uniform diagonal weights (per-channel variation)
# 3. Full matrices (channel mixing, rotation/shear)
# 4. Multiple input points (error manifold)

import os
import warnings
from dataclasses import dataclass, field
from typing import List, Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import ConvexHull, QhullError
warnings.filterwarnings('ignore')

# ============================================================
# Configuration
# ============================================================

# Global settings
BITS = 8                     # bit-width used to derive the grid spacing
DELTA = 2.0 ** (1 - BITS)    # grid spacing, i.e. 1 / 2**(BITS - 1)
N_LAYERS = 4                 # number of layers in every experiment

# Shared axis half-extent for all error plots. It stays None until the
# experiments have run and the largest error extent is known.
GLOBAL_ERROR_SCALE = None

# Colors
COLORS = dict(
    layer1='#1f77b4',
    layer2='#ff7f0e',
    layer3='#2ca02c',
    layer4='#d62728',
    cumulative='#e377c2',
    input='#17becf',
    error_region='#ff6b6b',
    reference='#888888',
)


# ============================================================
# Data structures for storing results
# ============================================================

@dataclass
class LayerStats:
    """Statistics recorded for a single layer while an experiment runs.

    Instances are created by run_experiment(), one per layer, in layer order.
    """
    layer_idx: int  # Zero-based index of the layer within the experiment
    weight_matrix: np.ndarray  # Copy of the quantized weight matrix used for this layer
    spectral_norm: float  # 2-norm of the quantized weight matrix
    determinant: float  # Determinant of the quantized weight matrix
    condition_number: float  # Largest / smallest singular value; inf when the smallest is 0
    error_half_widths: np.ndarray  # In input space: per-dimension max |coordinate| of this layer's error vertices
    error_volume: float  # Area in 2D of this layer's error region (input space)
    
@dataclass
class ExperimentStats:
    """Aggregated outcome of one experiment: one input point pushed through all layers."""
    name: str
    input_point: np.ndarray
    layer_stats: List[LayerStats]
    cumulative_error_vertices: np.ndarray
    cumulative_error_volume: float
    bounding_box: np.ndarray  # Row 0 = per-dimension min, row 1 = per-dimension max
    relative_error: np.ndarray  # One entry per channel

    def summary(self):
        """Return a digest of the experiment built from plain Python containers."""
        per_layer_norms = []
        for layer in self.layer_stats:
            per_layer_norms.append(layer.spectral_norm)

        digest = dict(
            name=self.name,
            input=self.input_point.tolist(),
            final_volume=self.cumulative_error_volume,
            bbox=self.bounding_box.tolist(),
            relative_error=self.relative_error.tolist(),
            spectral_norms=per_layer_norms,
        )
        return digest


@dataclass
class AllExperimentStats:
    """Registry of experiment results, keyed by experiment name."""
    experiments: Dict[str, ExperimentStats] = field(default_factory=dict)

    def add(self, stats: ExperimentStats):
        """Store ``stats`` under ``stats.name``, replacing any entry with that name."""
        key = stats.name
        self.experiments[key] = stats

    def print_summary(self):
        """Print a short report for every stored experiment, in insertion order."""
        rule = "=" * 70
        print("\n" + rule)
        print("SUMMARY OF ALL EXPERIMENTS")
        print(rule)
        for label, record in self.experiments.items():
            report_lines = (
                f"\n{label}:",
                f"  Input: {record.input_point}",
                f"  Final error volume: {record.cumulative_error_volume:.6f}",
                f"  Bounding box: {record.bounding_box}",
                f"  Relative error: {record.relative_error}",
            )
            for line in report_lines:
                print(line)


# Global stats container: run_experiment() registers every result here, and
# the main script later passes it to plot_summary() and calls print_summary().
ALL_STATS = AllExperimentStats()


# ============================================================
# Core functions
# ============================================================

def quantize(W, delta=DELTA):
    """Snap every entry of W onto the uniform grid of multiples of delta."""
    grid_index = np.round(W / delta)
    return grid_index * delta


def get_box_vertices_2d(half_widths):
    """Return the four corners of an origin-centred, axis-aligned 2D box.

    Corners are listed starting at (-x, -y), then (-x, +y), (+x, +y), (+x, -y).
    """
    extents = np.array(half_widths)
    wx, wy = extents[0], extents[1]
    corners = [
        [-wx, -wy],
        [-wx, wy],
        [wx, wy],
        [wx, -wy],
    ]
    return np.array(corners)


def minkowski_sum_2d(V1, V2):
    """Minkowski sum of two 2D vertex sets.

    Every vertex of V1 is added to every vertex of V2 and, when possible,
    the result is reduced to the vertices of its convex hull.

    Args:
        V1: iterable of 2D points (first summand).
        V2: iterable of 2D points (second summand).

    Returns:
        Array of hull vertices of the pairwise sums. If fewer than three
        sums exist, or the hull cannot be built (e.g. all sums are
        collinear or coincident), the raw pairwise sums are returned
        unchanged, ordered with V1 as the outer loop.
    """
    sums = np.array([v1 + v2 for v1 in V1 for v2 in V2])

    if len(sums) < 3:
        return sums

    try:
        hull = ConvexHull(sums)
    except (QhullError, ValueError):
        # Degenerate geometry is expected here (zero-size error boxes);
        # fall back to the unreduced point set instead of failing.
        return sums
    return sums[hull.vertices]


def compute_polygon_area(vertices):
    """Return the area of the convex hull of a set of 2D points.

    Note that the area is that of the convex hull, not of the polygon
    traced by the vertices in their given order; the two agree only for
    convex polygons.

    Args:
        vertices: sequence of 2D points.

    Returns:
        The hull area as a float, or 0.0 when there are fewer than three
        points or the hull cannot be built (e.g. collinear points).
    """
    if len(vertices) < 3:
        return 0.0
    try:
        hull = ConvexHull(vertices)
    except (QhullError, ValueError):
        # Degenerate point sets enclose no area.
        return 0.0
    return hull.volume  # In 2D, 'volume' is area


def transform_vertices(vertices, W):
    """Map each row-vector vertex v to W @ v and return the stacked results."""
    return np.matmul(vertices, W.T)


def draw_polygon(ax, vertices, color, alpha=0.3, edgecolor=None, linewidth=2, label=None):
    """Draw the convex hull of ``vertices`` as a filled, outlined polygon.

    Args:
        ax: Matplotlib axes to draw on.
        vertices: (n, 2) array of points.
        color: Fill colour (also the outline colour unless ``edgecolor`` is set).
        alpha: Fill opacity.
        edgecolor: Outline colour; falls back to ``color`` when None.
        linewidth: Outline width.
        label: Legend label attached to the fill (or to the scatter fallback).

    When there are fewer than three points, or no convex hull can be built
    (e.g. collinear points), the points are drawn as scatter markers
    instead. Errors raised by the drawing calls themselves are not
    suppressed.
    """
    hull = None
    if len(vertices) >= 3:
        try:
            hull = ConvexHull(vertices)
        except (QhullError, ValueError):
            hull = None  # Degenerate geometry: use the scatter fallback

    if hull is None:
        ax.scatter(vertices[:, 0], vertices[:, 1], c=color, s=50, label=label)
        return

    outline = vertices[hull.vertices]
    outline = np.vstack([outline, outline[0]])  # Close polygon

    ax.fill(outline[:, 0], outline[:, 1], color=color, alpha=alpha, label=label)
    ax.plot(outline[:, 0], outline[:, 1], color=edgecolor or color, linewidth=linewidth)


def set_fixed_scale(ax, scale, center=(0, 0)):
    """Give ``ax`` a square window of half-width ``scale`` around ``center``.

    Also forces an equal aspect ratio, draws thin lines through x=0 and
    y=0, and enables a faint grid.
    """
    cx, cy = center[0], center[1]
    ax.set_xlim(cx - scale, cx + scale)
    ax.set_ylim(cy - scale, cy + scale)
    ax.set_aspect('equal')
    for draw_zero_line in (ax.axhline, ax.axvline):
        draw_zero_line(0, color='k', linewidth=0.5)
    ax.grid(True, alpha=0.3)


# ============================================================
# Experiment runner
# ============================================================

def run_experiment(name, x_input, weight_matrices, compute_error_fn):
    """
    Run an experiment and collect statistics.

    The weights are quantized, the input is propagated through the
    quantized layers, and at every layer ``compute_error_fn`` supplies the
    local error vertices. Each local region is mapped through the inverse
    of the cumulative transform (up to and including that layer) and
    Minkowski-summed into the running cumulative region.

    Side effect: the resulting stats are registered in ``ALL_STATS``.

    Args:
        name: Experiment name
        x_input: Input point (2D)
        weight_matrices: List of weight matrices (true, pre-quantization)
        compute_error_fn: Callable ``(val, W, delta) -> vertices`` giving the
            error vertices at a layer, where ``val`` is the layer input and
            ``W`` the quantized weight matrix

    Returns:
        Tuple ``(stats, quant_weights)``: the ExperimentStats object and the
        list of quantized weight matrices.

    Raises:
        ValueError: if ``weight_matrices`` is empty.
    """
    quant_weights = [quantize(W) for W in weight_matrices]
    if not quant_weights:
        raise ValueError("weight_matrices must contain at least one layer")

    layer_stats = []
    val = x_input.copy()
    cumulative_W = np.eye(2)
    cumulative_error_vertices = None

    for i, W in enumerate(quant_weights):
        # Compute layer statistics
        spectral_norm = np.linalg.norm(W, ord=2)
        det = np.linalg.det(W)
        svd = np.linalg.svd(W, compute_uv=False)
        cond = svd.max() / svd.min() if svd.min() > 0 else np.inf

        # Compute error vertices for this layer
        local_error_vertices = compute_error_fn(val, W, DELTA)

        # Map to input space
        cumulative_W_after = W @ cumulative_W
        try:
            inv_W = np.linalg.inv(cumulative_W_after)
        except np.linalg.LinAlgError:
            # Singular cumulative transform: keep the unmapped local region
            # as a best-effort stand-in.
            error_vertices_input = local_error_vertices
        else:
            error_vertices_input = transform_vertices(local_error_vertices, inv_W)

        # Minkowski sum
        if cumulative_error_vertices is None:
            cumulative_error_vertices = error_vertices_input
        else:
            cumulative_error_vertices = minkowski_sum_2d(cumulative_error_vertices, error_vertices_input)

        # Compute half-widths (bounding box of this layer's contribution)
        hw = np.abs(error_vertices_input).max(axis=0)

        layer_stats.append(LayerStats(
            layer_idx=i,
            weight_matrix=W.copy(),
            spectral_norm=spectral_norm,
            determinant=det,
            condition_number=cond,
            error_half_widths=hw,
            error_volume=compute_polygon_area(error_vertices_input)
        ))

        # Update for next layer
        val = W @ val
        cumulative_W = cumulative_W_after

    # Final statistics
    bbox_min = cumulative_error_vertices.min(axis=0)
    bbox_max = cumulative_error_vertices.max(axis=0)

    # Half the bounding-box width relative to |input|; the small constant
    # avoids division by zero for zero-valued input channels.
    rel_error = (bbox_max - bbox_min) / (2 * np.abs(x_input) + 1e-10)

    stats = ExperimentStats(
        name=name,
        input_point=x_input.copy(),
        layer_stats=layer_stats,
        cumulative_error_vertices=cumulative_error_vertices,
        cumulative_error_volume=compute_polygon_area(cumulative_error_vertices),
        bounding_box=np.array([bbox_min, bbox_max]),
        relative_error=rel_error
    )

    ALL_STATS.add(stats)
    return stats, quant_weights


# ============================================================
# EXPERIMENT 1: Uniform Diagonal Weights
# ============================================================

def exp1_error_fn(val, W, delta):
    """Error region for diagonal weights: an axis-aligned box.

    With a diagonal W the dimensions do not interact, so the error in
    dimension i is W_err[i, i] * val[i] with W_err[i, i] in
    [-delta/2, delta/2]. ``W`` itself is not used.
    """
    half_step = delta / 2
    return get_box_vertices_2d(half_step * np.abs(val))


def run_experiment_1(x_input):
    """Experiment 1: Uniform diagonal weights.

    Returns the (stats, quant_weights) pair produced by run_experiment().
    """
    rule = "=" * 70
    print("\n" + rule)
    print("EXPERIMENT 1: Uniform Diagonal Weights")
    print(rule)

    # Every layer is a scaled identity: the same gain on both channels
    layer_gains = (0.9, 1.1, 0.85, 1.05)
    weights = [np.eye(2) * gain for gain in layer_gains]

    outcome = run_experiment("Exp1: Uniform Diagonal", x_input, weights, exp1_error_fn)
    stats, quant_weights = outcome

    print(f"Input: {x_input}")
    print(f"Final error volume: {stats.cumulative_error_volume:.6f}")
    print(f"Relative error: {stats.relative_error}")

    return stats, quant_weights


# ============================================================
# EXPERIMENT 2: Non-Uniform Diagonal Weights  
# ============================================================

def exp2_error_fn(val, W, delta):
    """Error function for diagonal weights; identical to exp1_error_fn."""
    # Delegate so the diagonal-case error model lives in one place.
    return exp1_error_fn(val, W, delta)


def run_experiment_2(x_input):
    """Experiment 2: Non-uniform diagonal weights.

    Returns the (stats, quant_weights) pair produced by run_experiment().
    """
    rule = "=" * 70
    print("\n" + rule)
    print("EXPERIMENT 2: Non-Uniform Diagonal Weights")
    print(rule)

    # (channel 0 gain, channel 1 gain) for each layer
    channel_gains = [
        (0.8, 1.2),   # Ch1 amplified
        (1.1, 0.7),   # Ch0 amplified, Ch1 shrunk
        (0.9, 1.1),   # Mild
        (1.2, 0.8),   # Ch0 amplified
    ]
    weights = [np.diag(gains) for gains in channel_gains]

    outcome = run_experiment("Exp2: Non-Uniform Diagonal", x_input, weights, exp2_error_fn)
    stats, quant_weights = outcome

    print(f"Input: {x_input}")
    print(f"Final error volume: {stats.cumulative_error_volume:.6f}")
    print(f"Relative error: {stats.relative_error}")

    return stats, quant_weights


# ============================================================
# EXPERIMENT 3: Full Matrices (Non-Diagonal)
# ============================================================

def exp3_error_fn(val, W, delta):
    """
    Error region for full (non-diagonal) matrices.

    The output error is W_err @ val with every W_err[i, j] independent:
        Output_err[i] = sum_j W_err[i, j] * val[j]

    Each term lies in [-delta/2 * |val[j]|, delta/2 * |val[j]|], so every
    output dimension is bounded by delta/2 * ||val||_1. The region is
    therefore an axis-aligned square in output space. ``W`` is not used.
    """
    # Same bound in both output dimensions: delta/2 times the L1 norm of val
    bound = (delta / 2) * np.sum(np.abs(val))
    return get_box_vertices_2d(bound * np.ones(2))


def run_experiment_3(x_input):
    """Experiment 3: Full matrices with off-diagonal elements.

    Besides the usual report, prints the singular values and condition
    number of the product of all quantized layers. Returns the
    (stats, quant_weights) pair produced by run_experiment().
    """
    rule = "=" * 70
    print("\n" + rule)
    print("EXPERIMENT 3: Full Matrices (Non-Diagonal)")
    print(rule)

    # Matrices with rotation/shear
    weights = [
        np.array([[0.9, 0.2], [0.1, 1.0]]),
        np.array([[0.95, -0.15], [0.2, 0.85]]),
        np.array([[1.0, 0.1], [-0.1, 0.9]]),
        np.array([[0.85, 0.15], [0.1, 1.05]]),
    ]

    outcome = run_experiment("Exp3: Full Matrices", x_input, weights, exp3_error_fn)
    stats, quant_weights = outcome

    print(f"Input: {x_input}")
    print(f"Final error volume: {stats.cumulative_error_volume:.6f}")
    print(f"Relative error: {stats.relative_error}")

    # Additional: properties of the end-to-end transform (later layers on the left)
    total_transform = np.eye(2)
    for layer in quant_weights:
        total_transform = layer @ total_transform

    _u, sing_vals, _vt = np.linalg.svd(total_transform)
    print(f"Cumulative transform singular values: {sing_vals}")
    print(f"Cumulative transform condition number: {sing_vals.max()/sing_vals.min():.3f}")

    return stats, quant_weights


# ============================================================
# EXPERIMENT 4: Multiple Input Points
# ============================================================

def run_experiment_4(base_weights, n_points=32, radius=20):
    """Experiment 4: Error manifold - multiple input points.

    Quantizes ``base_weights`` and, for evenly spaced points on a circle
    around the origin, accumulates the cumulative error region through all
    layers using the full-matrix error model (exp3_error_fn).

    Args:
        base_weights: List of weight matrices (pre-quantization).
        n_points: Number of points sampled on the circle.
        radius: Radius of the circle.

    Returns:
        Tuple ``(results, circle_points, quant_weights)``. ``results`` holds
        one dict per point with keys 'input', 'error_vertices',
        'error_magnitude' (largest vertex norm) and 'error_volume'.
    """
    print("\n" + "=" * 70)
    print("EXPERIMENT 4: Multiple Input Points (Error Manifold)")
    print("=" * 70)

    quant_weights = [quantize(W) for W in base_weights]

    # Generate input manifold: circle
    theta = np.linspace(0, 2*np.pi, n_points, endpoint=False)
    circle_points = np.column_stack([radius * np.cos(theta), radius * np.sin(theta)])

    # Compute error for each point
    results = []
    for x in circle_points:
        val = x.copy()
        cumulative_W = np.eye(2)
        cumulative_error_vertices = None

        for W in quant_weights:
            local_error_vertices = exp3_error_fn(val, W, DELTA)
            cumulative_W_after = W @ cumulative_W

            try:
                inv_W = np.linalg.inv(cumulative_W_after)
            except np.linalg.LinAlgError:
                # Singular cumulative transform: keep the unmapped local
                # region as a best-effort stand-in.
                error_vertices_input = local_error_vertices
            else:
                error_vertices_input = transform_vertices(local_error_vertices, inv_W)

            if cumulative_error_vertices is None:
                cumulative_error_vertices = error_vertices_input
            else:
                cumulative_error_vertices = minkowski_sum_2d(cumulative_error_vertices, error_vertices_input)

            val = W @ val
            cumulative_W = cumulative_W_after

        error_magnitude = np.max(np.linalg.norm(cumulative_error_vertices, axis=1))
        results.append({
            'input': x.copy(),
            'error_vertices': cumulative_error_vertices.copy(),
            'error_magnitude': error_magnitude,
            'error_volume': compute_polygon_area(cumulative_error_vertices)
        })

    # Statistics
    magnitudes = [r['error_magnitude'] for r in results]
    print(f"Error magnitude range: [{min(magnitudes):.4f}, {max(magnitudes):.4f}]")
    print(f"Variation ratio: {max(magnitudes)/min(magnitudes):.2f}x")

    return results, circle_points, quant_weights


# ============================================================
# Plotting functions
# ============================================================

def _plot_diagonal_row(row_axes, stats, scale, tag, region_title, relative_note):
    """Fill one row of the Exp 1/2 figure: error region, per-layer bars, relative error."""
    # Panel 1: cumulative error region
    ax = row_axes[0]
    draw_polygon(ax, stats.cumulative_error_vertices, COLORS['error_region'], alpha=0.4)
    set_fixed_scale(ax, scale)
    ax.set_title(f"{tag}: {region_title}\nVolume: {stats.cumulative_error_volume:.6f}")
    ax.set_xlabel('Dim 0')
    ax.set_ylabel('Dim 1')

    # Panel 2: per-layer contributions, labelled from this experiment's own layers
    ax = row_axes[1]
    layers = [ls.layer_idx + 1 for ls in stats.layer_stats]
    hw0 = [ls.error_half_widths[0] for ls in stats.layer_stats]
    hw1 = [ls.error_half_widths[1] for ls in stats.layer_stats]
    x_pos = np.arange(len(layers))
    ax.bar(x_pos - 0.2, hw0, 0.4, label='Dim 0', color=COLORS['layer1'])
    ax.bar(x_pos + 0.2, hw1, 0.4, label='Dim 1', color=COLORS['layer2'])
    ax.set_xticks(x_pos)
    ax.set_xticklabels([f'L{n}' for n in layers])
    ax.set_ylabel('Error half-width')
    ax.set_title(f'{tag}: Per-layer error')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Panel 3: relative error per channel, in percent
    ax = row_axes[2]
    ax.bar(['Dim 0', 'Dim 1'], stats.relative_error * 100, color=[COLORS['layer1'], COLORS['layer2']])
    ax.set_ylabel('Relative error (%)')
    ax.set_title(f'{tag}: Relative error\n({relative_note})')
    ax.grid(True, alpha=0.3)


def plot_experiment_1_2(stats1, stats2, scale):
    """Plot comparison of Exp 1 and 2.

    Draws a 2x3 figure (one row per experiment), saves it to
    'plots/exp1_2_comparison.png' and shows it.

    Args:
        stats1: ExperimentStats of Experiment 1 (top row).
        stats2: ExperimentStats of Experiment 2 (bottom row).
        scale: Axis half-extent shared by both error-region panels.
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    _plot_diagonal_row(axes[0], stats1, scale, 'Exp 1', 'Uniform Diagonal',
                       'Should be equal for uniform weights')
    _plot_diagonal_row(axes[1], stats2, scale, 'Exp 2', 'Non-Uniform Diagonal',
                       'Now DIFFERENT per channel')

    plt.tight_layout()
    # savefig does not create missing directories
    os.makedirs('plots', exist_ok=True)
    plt.savefig('plots/exp1_2_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()


def plot_experiment_3(stats3, stats1, scale):
    """Plot Experiment 3 results.

    Draws three panels: the Exp 3 error region over the Exp 1 region, the
    singular values of the centred Exp 3 vertices, and the region with its
    principal directions. Saves to 'plots/exp3_full_matrices.png' and shows
    the figure.

    Args:
        stats3: ExperimentStats of Experiment 3 (full matrices).
        stats1: ExperimentStats of Experiment 1 (diagonal reference).
        scale: Axis half-extent used for the region panels.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    region = stats3.cumulative_error_vertices
    center = region.mean(axis=0)

    # Error region comparison
    ax = axes[0]
    draw_polygon(ax, stats1.cumulative_error_vertices, COLORS['layer1'], alpha=0.3, label='Exp1 (diagonal)')
    draw_polygon(ax, region, COLORS['error_region'], alpha=0.4, label='Exp3 (full)')
    set_fixed_scale(ax, scale)
    ax.set_title('Error region comparison\nBlue=Diagonal, Red=Full matrices')
    ax.set_xlabel('Dim 0')
    ax.set_ylabel('Dim 1')
    ax.legend()

    # Shape analysis via SVD of the centred vertices
    ax = axes[1]
    _, S, Vt = np.linalg.svd(region - center, full_matrices=False)

    ax.bar(['PC1', 'PC2'], S, color=[COLORS['layer1'], COLORS['layer2']])
    ax.set_ylabel('Singular value')
    ax.set_title(f'Error region shape (SVD)\nCondition: {S[0]/S[1]:.2f}')
    ax.grid(True, alpha=0.3)

    # Principal directions
    ax = axes[2]
    draw_polygon(ax, region, COLORS['error_region'], alpha=0.3)

    # Draw principal axes as arrows from the centroid, scaled by singular value
    for i, (s, v) in enumerate(zip(S, Vt)):
        ax.arrow(center[0], center[1], v[0]*s*0.8, v[1]*s*0.8,
                head_width=scale*0.03, color=['blue', 'green'][i], linewidth=2,
                label=f'PC{i+1}: [{v[0]:.2f}, {v[1]:.2f}]')

    set_fixed_scale(ax, scale)
    ax.set_title('Principal directions\n(Error is anisotropic)')
    ax.set_xlabel('Dim 0')
    ax.set_ylabel('Dim 1')
    ax.legend(loc='upper left', fontsize=8)

    plt.tight_layout()
    # savefig does not create missing directories
    os.makedirs('plots', exist_ok=True)
    plt.savefig('plots/exp3_full_matrices.png', dpi=150, bbox_inches='tight')
    plt.show()


def plot_experiment_4(results, circle_points, scale):
    """Plot Experiment 4 results.

    Draws a 2x2 figure: input points coloured by error magnitude, error
    magnitude versus angle, error regions for a subset of points, and the
    input manifold with scaled error regions overlaid. Saves to
    'plots/exp4_error_manifold.png' and shows the figure.

    Args:
        results: Per-point dicts as returned by run_experiment_4() (keys
            'input', 'error_vertices', 'error_magnitude' are read here).
        circle_points: (n, 2) array of the input points, aligned with ``results``.
        scale: Axis half-extent for the selected-regions panel.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 14))

    # Error magnitude around the circle
    ax = axes[0, 0]
    magnitudes = [r['error_magnitude'] for r in results]
    scatter = ax.scatter(circle_points[:, 0], circle_points[:, 1],
                        c=magnitudes, cmap='hot', s=100, edgecolors='black')
    plt.colorbar(scatter, ax=ax, label='Error magnitude')
    ax.plot(circle_points[:, 0], circle_points[:, 1], 'b-', alpha=0.3, linewidth=1)
    ax.set_xlabel('Input dim 0')
    ax.set_ylabel('Input dim 1')
    ax.set_title('Circle manifold colored by error')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

    # Error magnitude vs angle
    ax = axes[0, 1]
    angles = np.arctan2(circle_points[:, 1], circle_points[:, 0])
    ax.plot(np.degrees(angles), magnitudes, 'o-', linewidth=2, markersize=6)
    ax.set_xlabel('Angle (degrees)')
    ax.set_ylabel('Error magnitude')
    ax.set_title('Error varies with direction\n(Not constant around circle!)')
    ax.grid(True, alpha=0.3)

    # Selected error regions
    ax = axes[1, 0]
    n_selected = 8
    indices = np.linspace(0, len(results)-1, n_selected, dtype=int)
    colors_selected = plt.cm.hsv(np.linspace(0, 1, n_selected))

    for idx, color in zip(indices, colors_selected):
        r = results[idx]
        draw_polygon(ax, r['error_vertices'], color, alpha=0.3, linewidth=1)

    set_fixed_scale(ax, scale)
    # Title follows n_selected so the two cannot drift apart
    ax.set_title(f'Error regions for {n_selected} points around circle\n(Shape varies with input direction)')
    ax.set_xlabel('Dim 0')
    ax.set_ylabel('Dim 1')

    # Manifold with error regions overlaid
    ax = axes[1, 1]
    error_scale = 0.5  # Scale factor to visualize error regions at each point

    for r in results[::2]:  # Every other point for clarity
        vertices = r['error_vertices'] * error_scale + r['input']
        draw_polygon(ax, vertices, COLORS['error_region'], alpha=0.2, linewidth=0.5)

    ax.plot(circle_points[:, 0], circle_points[:, 1], 'b-', linewidth=2, label='Input manifold')
    ax.scatter(circle_points[:, 0], circle_points[:, 1], c='blue', s=20, zorder=5)
    ax.set_xlabel('Input dim 0')
    ax.set_ylabel('Input dim 1')
    ax.set_title('Input manifold with error "tubes"\n(Red = error region at each point)')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)
    ax.legend()

    plt.tight_layout()
    # savefig does not create missing directories
    os.makedirs('plots', exist_ok=True)
    plt.savefig('plots/exp4_error_manifold.png', dpi=150, bbox_inches='tight')
    plt.show()


def plot_summary(all_stats):
    """Summary comparison plot.

    Compares the first three experiments stored in ``all_stats``: total
    error volume, overlaid error regions, and per-layer spectral norms.
    Saves to 'plots/all_experiments_summary.png' and shows the figure.

    Args:
        all_stats: AllExperimentStats container holding the results.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    exp_names = list(all_stats.experiments.keys())[:3]  # First 3 experiments
    n_exps = len(exp_names)
    colors = [COLORS['layer1'], COLORS['layer2'], COLORS['layer3']]
    # Legend label: the text after the first ':' (the whole name if there is none)
    short_labels = [name.split(':', 1)[-1].strip() for name in exp_names]

    # Volume comparison
    ax = axes[0]
    volumes = [all_stats.experiments[name].cumulative_error_volume for name in exp_names]
    ax.bar(range(n_exps), volumes, color=colors[:n_exps])
    ax.set_xticks(range(n_exps))
    ax.set_xticklabels(['Uniform\nDiagonal', 'Non-Uniform\nDiagonal', 'Full\nMatrices'][:n_exps], fontsize=9)
    ax.set_ylabel('Error region volume (area)')
    ax.set_title('Total error volume comparison')
    ax.grid(True, alpha=0.3)

    # Overlay all error regions
    ax = axes[1]
    alphas = [0.4, 0.3, 0.2]

    for name, label, color, alpha in zip(exp_names, short_labels, colors, alphas):
        stats = all_stats.experiments[name]
        draw_polygon(ax, stats.cumulative_error_vertices, color, alpha=alpha,
                    label=label)

    # Find max scale
    max_extent = 0
    for name in exp_names:
        stats = all_stats.experiments[name]
        max_extent = max(max_extent, np.abs(stats.cumulative_error_vertices).max())

    set_fixed_scale(ax, max_extent * 1.2)
    ax.set_title('All error regions overlaid')
    ax.set_xlabel('Dim 0')
    ax.set_ylabel('Dim 1')
    ax.legend(loc='upper left', fontsize=8)

    # Spectral norms
    ax = axes[2]
    x_pos = np.arange(N_LAYERS)
    width = 0.25

    for i, (name, label, color) in enumerate(zip(exp_names, short_labels, colors)):
        stats = all_stats.experiments[name]
        norms = [ls.spectral_norm for ls in stats.layer_stats]
        ax.bar(x_pos + i*width, norms, width, label=label, color=color, alpha=0.7)

    ax.set_xticks(x_pos + width)
    ax.set_xticklabels([f'L{i+1}' for i in range(N_LAYERS)])
    ax.set_ylabel('Spectral norm')
    ax.set_title('Weight spectral norms by layer')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    # savefig does not create missing directories
    os.makedirs('plots', exist_ok=True)
    plt.savefig('plots/all_experiments_summary.png', dpi=150, bbox_inches='tight')
    plt.show()


# ============================================================
# Main execution
# ============================================================

if __name__ == "__main__":
    banner = "=" * 70

    # Input point shared by experiments 1-3
    x_input = np.array([10.0, 20.0])

    print(banner)
    print("QUANTIZATION ERROR GEOMETRY EXPERIMENTS")
    print(banner)
    print(f"Input point: {x_input}")
    print(f"Quantization bits: {BITS}")
    print(f"Delta: {DELTA}")

    # Single-point experiments
    stats1, weights1 = run_experiment_1(x_input)
    stats2, weights2 = run_experiment_2(x_input)
    stats3, weights3 = run_experiment_3(x_input)

    # One plotting scale for all three: largest |coordinate| plus 30% margin
    all_vertices = [s.cumulative_error_vertices for s in (stats1, stats2, stats3)]
    GLOBAL_ERROR_SCALE = max(np.abs(v).max() for v in all_vertices) * 1.3
    print(f"\nGlobal error scale for plots: {GLOBAL_ERROR_SCALE:.4f}")

    # Experiment 4 runs on the same (pre-quantization) matrices as Exp 3
    exp3_weights = [
        np.array([[0.9, 0.2], [0.1, 1.0]]),
        np.array([[0.95, -0.15], [0.2, 0.85]]),
        np.array([[1.0, 0.1], [-0.1, 0.9]]),
        np.array([[0.85, 0.15], [0.1, 1.05]]),
    ]
    results4, circle_points, weights4 = run_experiment_4(exp3_weights)

    # Figures
    print("\n" + banner)
    print("GENERATING PLOTS")
    print(banner)

    plot_experiment_1_2(stats1, stats2, GLOBAL_ERROR_SCALE)
    plot_experiment_3(stats3, stats1, GLOBAL_ERROR_SCALE)
    plot_experiment_4(results4, circle_points, GLOBAL_ERROR_SCALE)
    plot_summary(ALL_STATS)

    # Text report
    ALL_STATS.print_summary()

    print("\n" + banner)
    print("KEY TAKEAWAYS")
    print(banner)
    print("""
1. UNIFORM DIAGONAL: Error scales equally for all channels (relative error constant)
   
2. NON-UNIFORM DIAGONAL: Different channels accumulate error at different rates
   - Channels with larger cumulative weights have smaller relative error
   
3. FULL MATRICES: Error region becomes tilted/sheared
   - Can't analyze channels independently
   - Bounding box overestimates true error region
   - Principal components reveal error anisotropy
   
4. ERROR MANIFOLD: Error magnitude and shape vary with input position
   - Further from origin = larger absolute error  
   - Direction matters too (not just magnitude)
   - The "error tube" around a manifold has varying thickness
    """)

In [None]:
# ============================================================
# Manifold definitions
# ============================================================

def make_manifold(name, n_points=32, **kwargs):
    """
    Generate points on various 2D manifolds.

    Args:
        name: One of "circle", "ellipse", "line", "spiral", "figure_eight",
            "grid", "two_blobs", "crescent".
        n_points: Requested number of points. "grid" builds a square grid
            with ``int(sqrt(n_points))`` points per side, so it may return
            fewer points than requested (e.g. 25 for n_points=32).
        **kwargs: Shape parameters, all optional:
            circle: radius
            ellipse: a, b (semi-axes)
            line: start, end
            spiral: turns, r_min, r_max
            figure_eight: scale
            grid: extent
            two_blobs: center1, center2, std, seed. With ``seed`` set, a
                private RandomState makes the output reproducible;
                otherwise numpy's global random state is used.
            crescent: outer_r (only the outer half-circle arc is generated)

    Returns:
        points: (n_points, 2) array (see the note on "grid" above)
        metadata: dict with manifold properties

    Raises:
        ValueError: if ``name`` is not a known manifold.
    """

    if name == "circle":
        radius = kwargs.get('radius', 20)
        theta = np.linspace(0, 2*np.pi, n_points, endpoint=False)
        points = np.column_stack([radius * np.cos(theta), radius * np.sin(theta)])
        metadata = {'radius': radius, 'type': 'closed'}

    elif name == "ellipse":
        a = kwargs.get('a', 25)  # Semi-major axis
        b = kwargs.get('b', 10)  # Semi-minor axis
        theta = np.linspace(0, 2*np.pi, n_points, endpoint=False)
        points = np.column_stack([a * np.cos(theta), b * np.sin(theta)])
        metadata = {'a': a, 'b': b, 'type': 'closed'}

    elif name == "line":
        start = np.array(kwargs.get('start', [-25, -10]))
        end = np.array(kwargs.get('end', [25, 10]))
        t = np.linspace(0, 1, n_points)
        points = start + t[:, np.newaxis] * (end - start)
        metadata = {'start': start, 'end': end, 'type': 'open'}

    elif name == "spiral":
        turns = kwargs.get('turns', 2)
        r_min = kwargs.get('r_min', 5)
        r_max = kwargs.get('r_max', 25)
        theta = np.linspace(0, turns * 2 * np.pi, n_points)
        r = np.linspace(r_min, r_max, n_points)
        points = np.column_stack([r * np.cos(theta), r * np.sin(theta)])
        metadata = {'turns': turns, 'r_min': r_min, 'r_max': r_max, 'type': 'open'}

    elif name == "figure_eight":
        scale = kwargs.get('scale', 15)
        t = np.linspace(0, 2*np.pi, n_points, endpoint=False)
        points = np.column_stack([scale * np.sin(t), scale * np.sin(t) * np.cos(t)])
        metadata = {'scale': scale, 'type': 'closed'}

    elif name == "grid":
        extent = kwargs.get('extent', 25)
        n_side = int(np.sqrt(n_points))
        x = np.linspace(-extent, extent, n_side)
        y = np.linspace(-extent, extent, n_side)
        xx, yy = np.meshgrid(x, y)
        points = np.column_stack([xx.ravel(), yy.ravel()])
        metadata = {'extent': extent, 'n_side': n_side, 'type': 'area'}

    elif name == "two_blobs":
        n_each = n_points // 2
        center1 = np.array(kwargs.get('center1', [-15, 0]))
        center2 = np.array(kwargs.get('center2', [15, 0]))
        std = kwargs.get('std', 5)
        seed = kwargs.get('seed')
        # Without a seed keep drawing from numpy's global state, as before
        rng = np.random if seed is None else np.random.RandomState(seed)
        blob1 = rng.randn(n_each, 2) * std + center1
        blob2 = rng.randn(n_points - n_each, 2) * std + center2
        points = np.vstack([blob1, blob2])
        metadata = {'center1': center1, 'center2': center2, 'std': std, 'type': 'clusters'}

    elif name == "crescent":
        outer_r = kwargs.get('outer_r', 25)
        theta = np.linspace(0, np.pi, n_points)
        # Just use outer arc for simplicity
        points = np.column_stack([outer_r * np.cos(theta), outer_r * np.sin(theta)])
        metadata = {'outer_r': outer_r, 'type': 'open'}

    else:
        raise ValueError(f"Unknown manifold: {name}")

    return points, metadata


# ============================================================
# Run experiment 4 across all manifolds
# ============================================================

def compute_manifold_errors(points, quant_weights, delta=DELTA):
    """Compute error statistics for all points on a manifold.

    Each point is propagated through ``quant_weights``. At every layer a
    square error box with half-width ``delta/2 * ||val||_1`` (``val`` being
    the layer input) is mapped through the inverse of the cumulative
    transform and Minkowski-summed into the point's cumulative region.

    Args:
        points: Iterable of 2D input points.
        quant_weights: List of (already quantized) 2x2 weight matrices.
        delta: Quantization grid spacing.

    Returns:
        List with one dict per point, with keys 'input', 'error_vertices',
        'error_magnitude' (largest vertex norm), 'error_volume',
        'l1_norm' and 'l2_norm' (norms of the input point).
    """
    results = []

    for x in points:
        val = x.copy()
        cumulative_W = np.eye(2)
        cumulative_error_vertices = None

        for W in quant_weights:
            # Local error box
            l1_norm = np.sum(np.abs(val))
            hw = (delta / 2) * l1_norm
            local_vertices = get_box_vertices_2d([hw, hw])

            # Map to input space
            cumulative_W_after = W @ cumulative_W
            try:
                inv_W = np.linalg.inv(cumulative_W_after)
            except np.linalg.LinAlgError:
                # Singular cumulative transform: keep the unmapped local
                # box as a best-effort stand-in.
                error_vertices_input = local_vertices
            else:
                error_vertices_input = transform_vertices(local_vertices, inv_W)

            # Minkowski sum
            if cumulative_error_vertices is None:
                cumulative_error_vertices = error_vertices_input
            else:
                cumulative_error_vertices = minkowski_sum_2d(
                    cumulative_error_vertices, error_vertices_input
                )

            val = W @ val
            cumulative_W = cumulative_W_after

        error_magnitude = np.max(np.linalg.norm(cumulative_error_vertices, axis=1))
        error_volume = compute_polygon_area(cumulative_error_vertices)

        results.append({
            'input': x.copy(),
            'error_vertices': cumulative_error_vertices.copy(),
            'error_magnitude': error_magnitude,
            'error_volume': error_volume,
            'l1_norm': np.sum(np.abs(x)),
            'l2_norm': np.linalg.norm(x)
        })

    return results


def run_all_manifolds(quant_weights, manifold_names=None, n_points=32):
    """Run the per-point error analysis on each requested manifold.

    Args:
        quant_weights: Weight matrices forwarded to
            ``compute_manifold_errors``.
        manifold_names: Names understood by ``make_manifold``. When ``None``,
            the six built-in manifolds are used.
        n_points: Number of points requested from ``make_manifold``.

    Returns:
        Dict keyed by manifold name. Each value holds ``'points'``,
        ``'metadata'``, ``'results'`` (per-point dicts) and ``'stats'``
        (min/max/mean/std of the error magnitude, mean error volume, max/min
        variation ratio and correlation of error magnitude with L1 norm).
    """
    selected = (
        ['circle', 'ellipse', 'line', 'spiral', 'figure_eight', 'two_blobs']
        if manifold_names is None
        else manifold_names
    )

    summary = {}
    for manifold in selected:
        print(f"Processing manifold: {manifold}")
        pts, meta = make_manifold(manifold, n_points=n_points)
        per_point = compute_manifold_errors(pts, quant_weights)

        # Pull the per-point series needed for the aggregate statistics
        mags = [entry['error_magnitude'] for entry in per_point]
        vols = [entry['error_volume'] for entry in per_point]
        l1s = [entry['l1_norm'] for entry in per_point]

        smallest = np.min(mags)
        largest = np.max(mags)
        stats = {
            'error_mag_min': smallest,
            'error_mag_max': largest,
            'error_mag_mean': np.mean(mags),
            'error_mag_std': np.std(mags),
            'error_vol_mean': np.mean(vols),
            'variation_ratio': largest / smallest,
            'correlation_l1': np.corrcoef(l1s, mags)[0, 1]
        }

        summary[manifold] = {
            'points': pts,
            'metadata': meta,
            'results': per_point,
            'stats': stats
        }

    return summary


# ============================================================
# Visualization for multiple manifolds
# ============================================================

def plot_manifold_comparison(all_results, scale=None):
    """Compare error patterns across manifolds.

    Draws and saves two figures:

    1. One scatter panel per manifold with points colored by error magnitude
       (``plots/manifolds_error_comparison.png``).
    2. Three bar charts: error magnitude range, max/min variation ratio and
       the correlation of error magnitude with the input L1 norm
       (``plots/manifolds_statistics.png``).

    Args:
        all_results: Mapping of manifold name to a dict with the keys
            ``'points'``, ``'metadata'`` (its ``'type'`` entry selects how the
            points are connected), ``'results'`` and ``'stats'``, as built by
            ``run_all_manifolds``.
        scale: Value handed to ``set_fixed_scale`` for every scatter panel.
            When ``None`` it is 1.4 times the largest absolute coordinate over
            all manifolds.
    """

    n_manifolds = len(all_results)
    n_cols = 3
    n_rows = (n_manifolds + n_cols - 1) // n_cols

    # Determine global scale if not provided
    if scale is None:
        max_input = 0
        for data in all_results.values():
            max_input = max(max_input, np.abs(data['points']).max())
        scale = max_input * 1.4

    # Figure 1: Manifolds colored by error
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 5*n_rows))
    axes = axes.flatten()

    for idx, (name, data) in enumerate(all_results.items()):
        ax = axes[idx]
        points = data['points']
        magnitudes = [r['error_magnitude'] for r in data['results']]

        scatter = ax.scatter(points[:, 0], points[:, 1],
                            c=magnitudes, cmap='hot', s=60, edgecolors='black', linewidth=0.5)
        plt.colorbar(scatter, ax=ax, label='Error mag')

        # Connect points: closed manifolds loop back to their first point
        if data['metadata']['type'] == 'closed':
            closed_points = np.vstack([points, points[0]])
            ax.plot(closed_points[:, 0], closed_points[:, 1], 'b-', alpha=0.3, linewidth=1)
        elif data['metadata']['type'] == 'open':
            ax.plot(points[:, 0], points[:, 1], 'b-', alpha=0.3, linewidth=1)

        stats = data['stats']
        ax.set_title(f"{name}\nVar ratio: {stats['variation_ratio']:.2f}x, "
                    f"Corr(L1): {stats['correlation_l1']:.2f}")
        set_fixed_scale(ax, scale)
        ax.set_xlabel('Dim 0')
        ax.set_ylabel('Dim 1')

    # Hide unused axes in the last grid row
    for ax in axes[n_manifolds:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.savefig('plots/manifolds_error_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Figure 2: Summary statistics
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    names = list(all_results.keys())

    # Error magnitude range
    ax = axes[0]
    mins = [all_results[n]['stats']['error_mag_min'] for n in names]
    maxs = [all_results[n]['stats']['error_mag_max'] for n in names]
    means = [all_results[n]['stats']['error_mag_mean'] for n in names]

    x_pos = np.arange(len(names))
    # Drawn tallest-first so the shorter bars stay visible on top
    ax.bar(x_pos, maxs, alpha=0.3, color='red', label='Max')
    ax.bar(x_pos, means, alpha=0.5, color='blue', label='Mean')
    ax.bar(x_pos, mins, alpha=0.7, color='green', label='Min')
    ax.set_xticks(x_pos)
    ax.set_xticklabels(names, rotation=45, ha='right')
    ax.set_ylabel('Error magnitude')
    ax.set_title('Error range by manifold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Variation ratio
    ax = axes[1]
    ratios = [all_results[n]['stats']['variation_ratio'] for n in names]
    ax.bar(x_pos, ratios, color='purple', alpha=0.7)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(names, rotation=45, ha='right')
    ax.set_ylabel('Max/Min error ratio')
    ax.set_title('Error variation within manifold\n(Higher = more position-dependent)')
    ax.axhline(1.0, color='gray', linestyle='--')
    ax.grid(True, alpha=0.3)

    # Correlation with L1 norm
    ax = axes[2]
    corrs = [all_results[n]['stats']['correlation_l1'] for n in names]
    colors = ['green' if c > 0.8 else 'orange' if c > 0.5 else 'red' for c in corrs]
    ax.bar(x_pos, corrs, color=colors, alpha=0.7)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(names, rotation=45, ha='right')
    ax.set_ylabel('Correlation')
    ax.set_title('Correlation: Error vs L1 norm\n(Green=predictable, Red=complex)')
    # A correlation can be negative; extend the axis downward in that case so
    # the bar is not drawn outside the visible range. NaN entries (e.g. from a
    # constant series) are ignored when picking the limit.
    finite_corrs = [c for c in corrs if np.isfinite(c)]
    y_low = min(0.0, min(finite_corrs, default=0.0))
    if y_low < 0:
        y_low -= 0.1
        ax.axhline(0.0, color='black', linewidth=0.8)
    ax.set_ylim(y_low, 1.1)
    ax.axhline(1.0, color='gray', linestyle='--')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('plots/manifolds_statistics.png', dpi=150, bbox_inches='tight')
    plt.show()


def plot_manifold_error_regions(all_results, manifold_name, n_show=8, scale=None):
    """Show error regions for selected points on a specific manifold.

    Picks ``n_show`` evenly spaced points of the manifold, draws each point's
    cumulative error polygon in its own panel (two rows) and saves the figure
    to ``plots/manifold_<manifold_name>_regions.png``.

    Args:
        all_results: Mapping of manifold name to a dict with ``'points'`` and
            ``'results'`` entries, as built by ``run_all_manifolds``.
        manifold_name: Key of the manifold to display.
        n_show: Number of points to display. Odd values leave one hidden
            spare panel.
        scale: Value handed to ``set_fixed_scale`` for every panel. When
            ``None`` it is 1.5 times the largest error magnitude on the
            manifold.
    """

    data = all_results[manifold_name]
    points = data['points']
    results = data['results']

    if scale is None:
        scale = np.max([r['error_magnitude'] for r in results]) * 1.5

    # Select evenly spaced points
    indices = np.linspace(0, len(points)-1, n_show, dtype=int)

    # Round the column count up so an odd n_show keeps its last point and
    # n_show < 2 still yields a valid grid; even n_show is laid out as before.
    n_cols = max(1, (n_show + 1) // 2)
    _, axes = plt.subplots(2, n_cols, figsize=(4 * n_cols, 8))
    axes = axes.flatten()

    for ax, idx in zip(axes, indices):
        r = results[idx]
        vertices = r['error_vertices']

        draw_polygon(ax, vertices, COLORS['error_region'], alpha=0.4)
        # Mark the origin of the error region
        ax.scatter([0], [0], c='black', s=50, zorder=5)

        ax.set_title(f"Input: ({r['input'][0]:.1f}, {r['input'][1]:.1f})\n"
                    f"Error: {r['error_magnitude']:.4f}")
        set_fixed_scale(ax, scale)

    # Hide panels that received no point
    for ax in axes[len(indices):]:
        ax.set_visible(False)

    plt.suptitle(f"Error regions on {manifold_name} manifold", fontsize=14)
    plt.tight_layout()
    plt.savefig(f'plots/manifold_{manifold_name}_regions.png', dpi=150, bbox_inches='tight')
    plt.show()

# ============================================================
# Main execution with manifolds
# ============================================================

if __name__ == "__main__":
    # Separator rules reused by every printed section
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    # Full (channel-mixing) matrices taken from experiment 3, one per layer
    base_weights = [
        np.array([[0.9, 0.2], [0.1, 1.0]]),
        np.array([[0.95, -0.15], [0.2, 0.85]]),
        np.array([[1.0, 0.1], [-0.1, 0.9]]),
        np.array([[0.85, 0.15], [0.1, 1.05]]),
    ]
    quant_weights = list(map(quantize, base_weights))

    print(heavy_rule)
    print("MANIFOLD COMPARISON")
    print(heavy_rule)

    # Analyse every manifold with a denser sampling than the default
    manifolds_to_run = ['circle', 'ellipse', 'line', 'spiral', 'figure_eight', 'two_blobs']
    all_results = run_all_manifolds(
        quant_weights,
        manifold_names=manifolds_to_run,
        n_points=48,
    )

    # Tabulated per-manifold summary
    print("\n" + light_rule)
    print("MANIFOLD STATISTICS")
    print(light_rule)
    print(f"{'Manifold':<15} {'Min Error':<12} {'Max Error':<12} {'Var Ratio':<12} {'Corr(L1)':<10}")
    print(light_rule)
    for manifold, entry in all_results.items():
        row = entry['stats']
        print(f"{manifold:<15} {row['error_mag_min']:<12.4f} {row['error_mag_max']:<12.4f} "
              f"{row['variation_ratio']:<12.2f} {row['correlation_l1']:<10.2f}")

    # Overview figures for all manifolds
    plot_manifold_comparison(all_results)

    # Detailed error regions for the spiral, the most varied case
    plot_manifold_error_regions(all_results, 'spiral', n_show=8)

    print("\n" + heavy_rule)
    print("OBSERVATIONS")
    print(heavy_rule)
    print("""
    1. CIRCLE: Constant radius but varying error → direction matters
    
    2. ELLIPSE: Similar to circle but stretched → error varies more
    
    3. LINE: Error grows with distance from origin → clear L1 relationship
    
    4. SPIRAL: Combines radial and angular effects → complex pattern
    
    5. FIGURE EIGHT: Self-intersecting → error high at extremes
    
    6. TWO BLOBS: Shows how clusters at different positions have different error
    
    KEY INSIGHT: Error depends on BOTH magnitude and direction of input.
    The correlation with L1 norm tells you how "predictable" the error is.
    Low correlation means direction effects dominate.
    """)