# In [None]:
# Experiment 1: Uniform Diagonal Weights
# ======================================
# 
# Setup:
# - 3×1 input "image": [10, 20, 30]
# - Each layer is a diagonal matrix with identical entries: w·I
# - No activations
# - Track error shapes mapped back to input space

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from itertools import product

# ============================================================
# Setup
# ============================================================

# Seed NumPy's global RNG so that any random draws in this experiment repeat.
np.random.seed(42)

# Three-channel input "image" that is pushed through every layer.
x_input = np.array([10.0, 20.0, 30.0], dtype=float)

# Quantization grid: with `bits` bits the weight step is 1 / 2**(bits - 1).
bits = 8
delta = 1.0 / (1 << (bits - 1))

# Full-precision scalar weight of each of the 4 layers; layer k applies
# true_weights[k] * I, i.e. the same factor on every channel.
true_weights = [0.8, 1.2, 0.9, 1.1]

def quantize(w, delta):
    """Snap a weight onto the uniform grid whose spacing is `delta`.

    The weight is expressed in units of `delta`, rounded to a whole number
    of steps with `np.round`, and scaled back to weight units.
    """
    steps = np.round(w / delta)
    return steps * delta

# Quantize every layer weight and record how far each one moved.
quant_weights = []
weight_errors = []
for w_fp in true_weights:
    w_q = quantize(w_fp, delta)
    quant_weights.append(w_q)
    weight_errors.append(w_q - w_fp)

print("Layer configurations:")
print("-" * 50)
layer_rows = zip(true_weights, quant_weights, weight_errors)
for layer_no, (tw, qw, we) in enumerate(layer_rows, start=1):
    print(f"Layer {layer_no}: true={tw:.4f}, quantized={qw:.6f}, error={we:.6f}")


# ============================================================
# Core computation: trace values and errors through layers
# ============================================================

def trace_through_network(x, true_weights, quant_weights):
    """
    Run `x` through the full-precision and the quantized network in lockstep.

    Each layer multiplies the running value by one scalar weight. For every
    layer the returned history records:
    - the quantized and the full-precision output,
    - the error introduced by that layer alone, (w_quant - w_true) times the
      value the quantized network feeds into the layer,
    - the cumulative output error (quantized minus full-precision output),
    - that cumulative error divided by the product of the quantized weights
      applied so far, i.e. expressed in input space.

    Returns a list of dicts. Entry 0 describes the untouched input (keys
    'layer', 'value' and three all-zero error arrays); entries 1..n describe
    the layers.
    """
    records = [{
        'layer': 0,
        'value': x.copy(),
        'error_this_layer': np.zeros_like(x),
        'cumulative_error': np.zeros_like(x),
        'cumulative_error_in_input_space': np.zeros_like(x)
    }]

    fp_value = x.copy()      # running value in the full-precision network
    q_value = x.copy()       # running value in the quantized network
    weight_product = 1.0     # product of the quantized weights applied so far

    layer_pairs = zip(true_weights, quant_weights)
    for layer_no, (w_true, w_quant) in enumerate(layer_pairs, start=1):
        # The fresh error uses the value ENTERING the layer in the quantized
        # network, so compute it before q_value is advanced.
        layer_error = (w_quant - w_true) * q_value

        fp_value = w_true * fp_value
        q_value = w_quant * q_value
        weight_product *= w_quant

        # Total deviation at this layer's output, and the same deviation
        # scaled back to input space.
        output_error = q_value - fp_value
        input_space_error = output_error / weight_product

        records.append({
            'layer': layer_no,
            'value_quant': q_value.copy(),
            'value_true': fp_value.copy(),
            'error_this_layer': layer_error,
            'cumulative_error': output_error,
            'cumulative_error_in_input_space': input_space_error,
            'weight_true': w_true,
            'weight_quant': w_quant,
            'cumulative_weight_product': weight_product
        })

    return records


history = trace_through_network(x_input, true_weights, quant_weights)

# Tabulate, layer by layer, the quantized value, the cumulative output error
# and that error expressed in input space.
print("\n\nValue and error propagation:")
print("-" * 80)
print(f"{'Layer':<8} {'Value (quant)':<25} {'Cumulative Error':<25} {'Error in Input Space':<25}")
print("-" * 80)
for h in history:
    layer = h['layer']
    if layer != 0:
        # Real layers: render the first three channels with fixed precision.
        val_str = "[" + ", ".join(f"{v:.3f}" for v in h['value_quant'][:3]) + "]"
        err_str = "[" + ", ".join(f"{v:.4f}" for v in h['cumulative_error'][:3]) + "]"
        err_input_str = "[" + ", ".join(
            f"{v:.4f}" for v in h['cumulative_error_in_input_space'][:3]
        ) + "]"
    else:
        # Row 0 is the raw input: no error exists yet.
        val_str = str(h['value'])
        err_str = "[0, 0, 0]"
        err_input_str = "[0, 0, 0]"
    print(f"{layer:<8} {val_str:<25} {err_str:<25} {err_input_str:<25}")


# ============================================================
# Compute error REGIONS (not just single errors)
# ============================================================

def compute_error_box_at_layer(x_at_layer, delta):
    """
    Half-widths of the output-error box caused by quantizing one layer.

    Rounding a weight to the grid moves it by at most delta/2 in either
    direction. The output error is weight_error * x, so along each dimension
    it lies in [-delta/2 * |x|, +delta/2 * |x|].

    Returns one non-negative half-width per entry of `x_at_layer`.
    """
    max_weight_error = delta / 2
    return max_weight_error * np.abs(x_at_layer)


def compute_error_boxes_through_network(x, quant_weights, delta):
    """
    Compute the error box contributed by each layer, mapped back to input space.

    Parameters:
    - x: input vector fed to the first layer.
    - quant_weights: iterable with one quantized weight per layer; each layer
      multiplies the running value by its weight.
    - delta: quantization step of the weights.

    Returns a list with one dict per layer holding:
    - 'layer': 1-based layer number,
    - 'box_half_widths_output_space': half-widths at the layer's output,
    - 'box_half_widths_input_space': the same half-widths divided by the
      magnitude of the product of all weights up to and including this layer,
    - 'value_at_layer_input': the value the layer receives,
    - 'cumulative_weight': the (signed) weight product after this layer.
    """
    boxes = []

    val = x.copy()
    cumulative_weight = 1.0

    for layer_no, w in enumerate(quant_weights, start=1):
        # Error box at this layer (in the layer's output space)
        box_at_layer = compute_error_box_at_layer(val, delta)

        # The layer's output space is input space scaled by the weight
        # product through this layer.
        cumulative_weight_after = cumulative_weight * w

        # Half-widths are magnitudes, so divide by the ABSOLUTE weight
        # product: a negative weight flips the box but must not turn its
        # half-widths negative (this matches the per-channel variant used in
        # Experiment 2).
        box_in_input_space = box_at_layer / np.abs(cumulative_weight_after)

        boxes.append({
            'layer': layer_no,
            'box_half_widths_output_space': box_at_layer.copy(),
            'box_half_widths_input_space': box_in_input_space.copy(),
            'value_at_layer_input': val.copy(),
            'cumulative_weight': cumulative_weight_after
        })

        # Advance the running value and weight product to the next layer
        val = w * val
        cumulative_weight = cumulative_weight_after

    return boxes


boxes = compute_error_boxes_through_network(x_input, quant_weights, delta)

# One table row per layer: its input-space half-widths and the value it saw.
print("\n\nError boxes per layer (mapped to input space):")
print("-" * 70)
print(f"{'Layer':<8} {'Box half-widths (input space)':<40} {'Value seen by layer':<20}")
print("-" * 70)
for b in boxes:
    hw = b['box_half_widths_input_space']
    val = b['value_at_layer_input']
    hw_text = ", ".join(f"{v:.6f}" for v in hw[:3])
    val_text = ", ".join(f"{v:.2f}" for v in val[:3])
    print(f"{b['layer']:<8} [{hw_text}]    [{val_text}]")

# Minkowski sum of axis-aligned boxes centred at the origin: the half-widths
# simply add up, layer after layer.
total_box_half_widths = sum(
    (b['box_half_widths_input_space'] for b in boxes),
    np.zeros(3),
)

total_text = ", ".join(f"{v:.6f}" for v in total_box_half_widths[:3])
print(f"\nTotal error box (Minkowski sum): [{total_text}]")
print(f"Total error range per channel:")
for i, hw in enumerate(total_box_half_widths):
    x = x_input[i]
    print(f"  Channel {i} (input={x}): error in [{-hw:.6f}, {+hw:.6f}], i.e., ±{100*hw/x:.3f}% of input")


# ============================================================
# Visualization
# ============================================================

def draw_box_3d(ax, center, half_widths, color, alpha=0.3, label=None):
    """Add an axis-aligned 3D box to `ax`.

    The box is centred at `center` and extends `half_widths[k]` in both
    directions along axis k. Faces are filled with `color` at opacity `alpha`
    and outlined in black. If `label` is truthy it is written just above the
    box along the third axis.
    """
    # Corner k has sign pattern product([-1, 1], repeat=3)[k]: the first
    # axis varies slowest, the third axis fastest.
    corner_signs = np.array(list(product([-1, 1], repeat=3)))
    corners = corner_signs * half_widths + center

    # Corner indices of the six faces, each listed in perimeter order.
    face_corner_ids = (
        (0, 1, 3, 2),  # first axis at its minimum
        (4, 5, 7, 6),  # first axis at its maximum
        (0, 1, 5, 4),  # second axis at its minimum
        (2, 3, 7, 6),  # second axis at its maximum
        (0, 2, 6, 4),  # third axis at its minimum
        (1, 3, 7, 5),  # third axis at its maximum
    )
    faces = [[corners[k] for k in ids] for ids in face_corner_ids]

    surface = Poly3DCollection(
        faces, alpha=alpha, facecolor=color, edgecolor='black', linewidth=0.5
    )
    ax.add_collection3d(surface)

    if label:
        label_height = center[2] + half_widths[2] * 1.2
        ax.text(center[0], center[1], label_height, label, fontsize=10)


# Figure 1: Error boxes per layer in input space
#
# Three 3D panels: (1) each layer's box drawn side by side, (2) the growing
# Minkowski sum as nested boxes, (3) the final total box with its extents.
# Axis limits are set explicitly on every panel: the boxes are added through
# add_collection3d, which is not guaranteed to rescale a 3D axes, so without
# explicit limits the boxes can fall outside the default view.
fig = plt.figure(figsize=(16, 5))

# Plot 1: Individual error boxes from each layer
ax1 = fig.add_subplot(131, projection='3d')

colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(boxes)))
offset = 0
extent_lo, extent_hi = 0.0, 0.0  # range covered by the boxes and their labels
for b, color in zip(boxes, colors):
    center = np.array([offset, offset, offset])  # Offset for visibility
    hw = b['box_half_widths_input_space']
    draw_box_3d(ax1, center, hw, color, alpha=0.5, label=f"L{b['layer']}")
    # The label sits at 1.2 half-widths above the centre, so reserve that much.
    reach = 1.2 * max(hw)
    extent_lo = min(extent_lo, offset - reach)
    extent_hi = max(extent_hi, offset + reach)
    offset += max(hw) * 2.5

if extent_hi > extent_lo:
    # Boxes sit on the main diagonal, so all three axes share one range.
    ax1.set_xlim(extent_lo, extent_hi)
    ax1.set_ylim(extent_lo, extent_hi)
    ax1.set_zlim(extent_lo, extent_hi)

ax1.set_xlabel('Channel 0 (x=10)')
ax1.set_ylabel('Channel 1 (x=20)')
ax1.set_zlabel('Channel 2 (x=30)')
ax1.set_title('Error boxes from each layer\n(separated for visibility)')


# Plot 2: Cumulative Minkowski sum
ax2 = fig.add_subplot(132, projection='3d')

cumulative_hw = np.zeros(3)
for b, color in zip(boxes, colors):
    cumulative_hw = cumulative_hw + b['box_half_widths_input_space']
    # Draw at same center, growing box
    draw_box_3d(ax2, np.zeros(3), cumulative_hw, color, alpha=0.2)

nested_limit = 1.2 * max(cumulative_hw)
if nested_limit > 0:
    ax2.set_xlim(-nested_limit, nested_limit)
    ax2.set_ylim(-nested_limit, nested_limit)
    ax2.set_zlim(-nested_limit, nested_limit)

ax2.set_xlabel('Channel 0 (x=10)')
ax2.set_ylabel('Channel 1 (x=20)')
ax2.set_zlabel('Channel 2 (x=30)')
ax2.set_title('Cumulative error box\n(Minkowski sum, nested)')


# Plot 3: Final error box with proportions
ax3 = fig.add_subplot(133, projection='3d')

draw_box_3d(ax3, np.zeros(3), total_box_half_widths, 'red', alpha=0.4)

# Mark the axes extents: one coloured segment per channel along its own axis
line_formats = ['b-', 'g-', 'r-']
for i, (hw, label) in enumerate(zip(total_box_half_widths, ['x=10', 'x=20', 'x=30'])):
    segment = [[0, 0], [0, 0], [0, 0]]
    segment[i] = [-hw, hw]
    ax3.plot(segment[0], segment[1], segment[2], line_formats[i], linewidth=3,
             label=f'Ch{i} ({label}): ±{hw:.4f}')

# Use one scale on all three axes so the box proportions (larger input →
# larger error) are visible instead of being normalised away per axis.
final_limit = 1.2 * max(total_box_half_widths)
if final_limit > 0:
    ax3.set_xlim(-final_limit, final_limit)
    ax3.set_ylim(-final_limit, final_limit)
    ax3.set_zlim(-final_limit, final_limit)

ax3.set_xlabel('Channel 0')
ax3.set_ylabel('Channel 1')
ax3.set_zlabel('Channel 2')
ax3.set_title('Final error box\nNote: larger input → larger error')
ax3.legend(loc='upper left', fontsize=8)

plt.tight_layout()
plt.savefig('plots/experiment1_error_boxes.png', dpi=150, bbox_inches='tight')
plt.show()


# Figure 2: Error growth through layers
# Left: per-layer contribution, middle: running total, right: running total
# as a percentage of the channel's input value.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
ax_contrib, ax_cumul, ax_rel = axes

layer_nums = [b['layer'] for b in boxes]

for ch in range(3):
    contributions = [b['box_half_widths_input_space'][ch] for b in boxes]
    cumulative = np.cumsum(contributions)
    percentage = 100 * cumulative / x_input[ch]

    # Grouped bars: shift each channel by a quarter unit around the layer number.
    ax_contrib.bar(np.array(layer_nums) + (ch - 1) * 0.25, contributions, width=0.25,
                   label=f'Channel {ch} (input={x_input[ch]:.0f})')
    ax_cumul.plot(layer_nums, cumulative, 'o-', linewidth=2, markersize=8,
                  label=f'Channel {ch} (input={x_input[ch]:.0f})')
    ax_rel.plot(layer_nums, percentage, 'o-', linewidth=2, markersize=8,
                label=f'Channel {ch}')

# Shared decoration: (axes, y-label, title) per panel.
panel_text = [
    (ax_contrib, 'Error half-width (input space)', 'Error contribution per layer'),
    (ax_cumul, 'Cumulative error half-width', 'Cumulative error (Minkowski sum)'),
    (ax_rel, 'Error as % of input value', 'Relative error\n(Note: same % for all channels!)'),
]
for ax, y_text, title_text in panel_text:
    ax.set_xlabel('Layer')
    ax.set_ylabel(y_text)
    ax.set_title(title_text)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('plots/experiment1_error_growth.png', dpi=150, bbox_inches='tight')
plt.show()


# ============================================================
# Key observations
# ============================================================

rule = "=" * 70
print("\n" + rule)
print("KEY OBSERVATIONS")
print(rule)

print("""
1. ERROR SCALES WITH INPUT MAGNITUDE
   - Channel 2 (input=30) has 3x the absolute error of Channel 0 (input=10)
   - But the RELATIVE error (as % of input) is the same for all channels
   - This is because we have uniform diagonal weights

2. ERROR ACCUMULATES THROUGH LAYERS
   - Each layer contributes its own error box
   - Total error is the Minkowski sum (for axis-aligned boxes, just add half-widths)
   - With 4 layers, error is ~4x a single layer (roughly, depends on weight magnitudes)

3. WEIGHT MAGNITUDE MATTERS
   - Weights > 1 amplify previous errors but also the current layer's error contribution
   - Weights < 1 shrink previous errors
   - The error from layer i, mapped to input space, is divided by the product of weights 1..i

4. THE ERROR BOX IS AXIS-ALIGNED
   - Because weights are diagonal, channels don't mix
   - Error in each channel is independent
   - This will change in Experiment 3 when we use non-diagonal weights
""")

# Print each channel's total half-width as a percentage of its input so the
# "same relative error" claim can be checked by eye.
print("\nVerification - relative error per channel:")
for ch in range(3):
    half_width, channel_input = total_box_half_widths[ch], x_input[ch]
    rel_error = half_width / channel_input
    print(f"  Channel {ch}: {100*rel_error:.6f}%")

# In [None]:
# Experiment 2: Non-Uniform Diagonal Weights
# ==========================================
# 
# Setup:
# - Same 3×1 input "image": [10, 20, 30]
# - Each layer is a diagonal matrix with DIFFERENT entries per channel
# - No activations
# - Track how different channels accumulate error at different rates

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from itertools import product

# ============================================================
# Setup
# ============================================================

# Seed NumPy's global RNG so that any random draws in this experiment repeat.
np.random.seed(42)

# Same three-channel input as in Experiment 1.
x_input = np.array([10.0, 20.0, 30.0], dtype=float)

# Quantization grid: with `bits` bits the weight step is 1 / 2**(bits - 1).
bits = 8
delta = 1.0 / (1 << (bits - 1))

# Diagonal of each layer's weight matrix: row = layer, column = channel.
# Unlike Experiment 1, the entries differ from channel to channel.
true_weights = np.array([
    [0.8, 1.2, 0.5],   # Layer 1: channel 1 amplified; channels 0 and 2 shrunk
    [1.1, 0.7, 1.3],   # Layer 2: channels 0 and 2 amplified; channel 1 shrunk
    [0.9, 1.1, 0.9],   # Layer 3: all entries close to 1
    [1.2, 0.8, 1.1],   # Layer 4: channels 0 and 2 amplified; channel 1 shrunk
])

n_layers = true_weights.shape[0]
n_channels = true_weights.shape[1]

def quantize(w, delta):
    """Snap weights onto the uniform grid whose spacing is `delta`.

    Works element-wise: the weights are expressed in units of `delta`,
    rounded to whole steps with `np.round`, and scaled back.
    """
    steps = np.round(w / delta)
    return steps * delta

# Quantize the whole weight table at once and keep the rounding error.
quant_weights = quantize(true_weights, delta)
weight_errors = quant_weights - true_weights

print("Layer configurations:")
print("-" * 70)
print(f"{'Layer':<8} {'Channel 0':<20} {'Channel 1':<20} {'Channel 2':<20}")
print("-" * 70)
for layer_no in range(1, n_layers + 1):
    true_row = true_weights[layer_no - 1]
    quant_row = quant_weights[layer_no - 1]
    # One "true→quantized" cell per channel.
    ch0 = f"{true_row[0]:.3f}→{quant_row[0]:.3f}"
    ch1 = f"{true_row[1]:.3f}→{quant_row[1]:.3f}"
    ch2 = f"{true_row[2]:.3f}→{quant_row[2]:.3f}"
    print(f"{layer_no:<8} {ch0:<20} {ch1:<20} {ch2:<20}")

# Running product of the quantized weights down the layers, per channel.
cumulative_products = np.cumprod(quant_weights, axis=0)
print("\nCumulative weight products (determines total amplification):")
print("-" * 70)
for layer_no in range(1, n_layers + 1):
    p0, p1, p2 = (cumulative_products[layer_no - 1, k] for k in range(3))
    print(f"After layer {layer_no}: [{p0:.4f}, {p1:.4f}, {p2:.4f}]")

# The last row is the end-to-end gain of each channel.
final_products = cumulative_products[-1]
print(f"\nFinal amplification: Ch0={final_products[0]:.4f}x, Ch1={final_products[1]:.4f}x, Ch2={final_products[2]:.4f}x")


# ============================================================
# Core computation: trace values and errors through layers
# ============================================================

def trace_through_network(x, true_weights, quant_weights):
    """
    Run `x` through the full-precision and quantized per-channel networks.

    `true_weights` and `quant_weights` are indexed by layer along their first
    axis; row i holds the per-channel factors of layer i + 1, applied
    element-wise to the running value.

    Returns a list of dicts. Entry 0 describes the input ('layer', 'value'
    and two all-zero error arrays). Each later entry holds the quantized and
    full-precision outputs, the quantized weights of that layer, the
    cumulative output error, that error divided by the per-channel product of
    quantized weights so far (input space), and the product itself.
    """
    records = [{
        'layer': 0,
        'value': x.copy(),
        'cumulative_error': np.zeros_like(x),
        'cumulative_error_in_input_space': np.zeros_like(x)
    }]

    fp_value = x.copy()               # running full-precision value
    q_value = x.copy()                # running quantized value
    weight_product = np.ones_like(x)  # per-channel product of quantized weights

    layer_count = true_weights.shape[0]
    for layer_no in range(1, layer_count + 1):
        w_true = true_weights[layer_no - 1]
        w_quant = quant_weights[layer_no - 1]

        # Diagonal layers act channel by channel.
        fp_value = w_true * fp_value
        q_value = w_quant * q_value
        weight_product = weight_product * w_quant

        # Deviation at this layer's output, then scaled back to input space.
        output_error = q_value - fp_value
        input_space_error = output_error / weight_product

        records.append({
            'layer': layer_no,
            'value_quant': q_value.copy(),
            'value_true': fp_value.copy(),
            # Copy: w_quant may be a view into the caller's weight table.
            'weight_quant': w_quant.copy(),
            'cumulative_error': output_error,
            'cumulative_error_in_input_space': input_space_error,
            'cumulative_weight_product': weight_product.copy()
        })

    return records


history = trace_through_network(x_input, true_weights, quant_weights)

# Tabulate, layer by layer, the quantized value, the cumulative output error
# and that error expressed in input space.
print("\n\nValue and error propagation:")
print("-" * 100)
print(f"{'Layer':<7} {'Value (quant)':<30} {'Cumulative Error':<30} {'Error in Input Space':<30}")
print("-" * 100)
for h in history:
    layer = h['layer']
    if layer != 0:
        # Real layers: render the first three channels with fixed precision.
        val_str = "[" + ", ".join(f"{v:.4f}" for v in h['value_quant'][:3]) + "]"
        err_str = "[" + ", ".join(f"{v:.5f}" for v in h['cumulative_error'][:3]) + "]"
        err_input_str = "[" + ", ".join(
            f"{v:.5f}" for v in h['cumulative_error_in_input_space'][:3]
        ) + "]"
    else:
        # Row 0 is the raw input: no error exists yet.
        val_str = "[" + ", ".join(f"{v:.2f}" for v in h['value'][:3]) + "]"
        err_str = "[0, 0, 0]"
        err_input_str = "[0, 0, 0]"
    print(f"{layer:<7} {val_str:<30} {err_str:<30} {err_input_str:<30}")


# ============================================================
# Compute error REGIONS per layer
# ============================================================

def compute_error_boxes_through_network(x, quant_weights, delta):
    """
    Per-layer quantization error boxes for a network of per-channel weights.

    `quant_weights` is indexed by layer along its first axis; row i holds the
    per-channel factors of layer i + 1. For every layer the half-widths are
    (delta / 2) * |value entering the layer|, and they are mapped to input
    space by dividing by the magnitude of the per-channel weight product up
    to and including that layer.

    Returns a list with one dict per layer: 'layer' (1-based),
    'box_half_widths_output_space', 'box_half_widths_input_space',
    'value_at_layer_input', 'weight_this_layer' and 'cumulative_weight'.
    """
    records = []

    layer_input = x.copy()
    weight_so_far = np.ones_like(x)

    for idx in range(quant_weights.shape[0]):
        w = quant_weights[idx]

        # A weight can be off by at most delta/2, and the resulting output
        # error is that offset times the value entering the layer.
        output_half_widths = (delta / 2) * np.abs(layer_input)

        # Scale back to input space with the weight product THROUGH this layer.
        weight_through_layer = weight_so_far * w
        input_half_widths = output_half_widths / np.abs(weight_through_layer)

        records.append({
            'layer': idx + 1,
            'box_half_widths_output_space': output_half_widths,
            'box_half_widths_input_space': input_half_widths,
            'value_at_layer_input': layer_input.copy(),
            # Copy: w may be a view into the caller's weight table.
            'weight_this_layer': w.copy(),
            'cumulative_weight': weight_through_layer.copy()
        })

        # Advance to the next layer.
        layer_input = w * layer_input
        weight_so_far = weight_through_layer

    return records


boxes = compute_error_boxes_through_network(x_input, quant_weights, delta)

# One table row per layer: its input-space half-widths and its weights.
print("\n\nError boxes per layer (mapped to input space):")
print("-" * 90)
print(f"{'Layer':<7} {'Half-widths (input space)':<45} {'Weight this layer':<25}")
print("-" * 90)
for b in boxes:
    hw = b['box_half_widths_input_space']
    w = b['weight_this_layer']
    hw_text = ", ".join(f"{v:.6f}" for v in hw[:3])
    w_text = ", ".join(f"{v:.3f}" for v in w[:3])
    print(f"{b['layer']:<7} [{hw_text}]   [{w_text}]")

# Minkowski sum of axis-aligned boxes centred at the origin: add half-widths.
total_box_half_widths = sum(
    (b['box_half_widths_input_space'] for b in boxes),
    np.zeros(3),
)

print(f"\nTotal error box (Minkowski sum):")
total_text = ", ".join(f"{v:.6f}" for v in total_box_half_widths[:3])
print(f"  Half-widths: [{total_text}]")
print(f"\nRelative error per channel:")
for i, channel_input in enumerate(x_input[:3]):
    rel_err = total_box_half_widths[i] / channel_input
    print(f"  Channel {i} (input={channel_input:.0f}): ±{100*rel_err:.4f}%")


# ============================================================
# Visualization
# ============================================================

def draw_box_3d(ax, center, half_widths, color, alpha=0.3, label=None):
    """Add an axis-aligned 3D box to `ax`.

    The box is centred at `center` and extends `half_widths[k]` in both
    directions along axis k. Faces are filled with `color` at opacity `alpha`
    and outlined in black.

    If `label` is truthy it is written just above the box along the third
    axis. The parameter used to be accepted but ignored here; it now behaves
    like the Experiment 1 version of this helper. Callers that pass no label
    see no change.
    """
    hw = np.array(half_widths)

    # Corner k has sign pattern product([-1, 1], repeat=3)[k]: the first
    # axis varies slowest, the third axis fastest.
    vertices = np.array(list(product([-1, 1], repeat=3))) * hw + center

    # Six faces, each listed as four corners in perimeter order.
    faces = [
        [vertices[0], vertices[1], vertices[3], vertices[2]],  # first axis at minimum
        [vertices[4], vertices[5], vertices[7], vertices[6]],  # first axis at maximum
        [vertices[0], vertices[1], vertices[5], vertices[4]],  # second axis at minimum
        [vertices[2], vertices[3], vertices[7], vertices[6]],  # second axis at maximum
        [vertices[0], vertices[2], vertices[6], vertices[4]],  # third axis at minimum
        [vertices[1], vertices[3], vertices[7], vertices[5]],  # third axis at maximum
    ]

    ax.add_collection3d(Poly3DCollection(
        faces, alpha=alpha, facecolor=color, edgecolor='black', linewidth=0.5
    ))

    if label:
        ax.text(center[0], center[1], center[2] + hw[2] * 1.2, label, fontsize=10)


# Figure 1: Compare uniform vs non-uniform weights
#
# Three 3D panels on a common scale: the non-uniform total error box, the
# total box of a uniform network whose every weight is the mean of
# true_weights, and an overlay of the two.
fig = plt.figure(figsize=(16, 6))

# Recompute uniform case for comparison
uniform_weights = np.ones_like(true_weights) * true_weights.mean()
uniform_quant = quantize(uniform_weights, delta)
boxes_uniform = compute_error_boxes_through_network(x_input, uniform_quant, delta)
total_uniform = sum(b['box_half_widths_input_space'] for b in boxes_uniform)

# Shared axis range. Take the larger of BOTH boxes: sizing the view from the
# non-uniform box alone would cut off the uniform box whenever it is bigger.
max_hw = max(max(total_box_half_widths), max(total_uniform))
axis_limit = max_hw * 1.2

# (subplot position, title, boxes to draw as (half-widths, colour, alpha))
panel_specs = [
    (131, 'Non-uniform weights\nError box',
     [(total_box_half_widths, 'red', 0.4)]),
    (132, 'Uniform weights (same mean)\nError box',
     [(total_uniform, 'blue', 0.4)]),
    (133, 'Overlay\nRed=Non-uniform, Blue=Uniform',
     [(total_box_half_widths, 'red', 0.3), (total_uniform, 'blue', 0.3)]),
]

panel_axes = []
for position, title_text, box_specs in panel_specs:
    panel = fig.add_subplot(position, projection='3d')
    for box_half_widths, box_color, box_alpha in box_specs:
        draw_box_3d(panel, np.zeros(3), box_half_widths, box_color, alpha=box_alpha)
    panel.set_xlabel('Ch0 (x=10)')
    panel.set_ylabel('Ch1 (x=20)')
    panel.set_zlabel('Ch2 (x=30)')
    panel.set_title(title_text)
    # Same range on every axis of every panel so the shapes are comparable
    panel.set_xlim(-axis_limit, axis_limit)
    panel.set_ylim(-axis_limit, axis_limit)
    panel.set_zlim(-axis_limit, axis_limit)
    panel_axes.append(panel)

ax1, ax2, ax3 = panel_axes

plt.tight_layout()
plt.savefig('plots/experiment2_comparison_boxes.png', dpi=150, bbox_inches='tight')
plt.show()


# Figure 2: Per-layer and per-channel analysis
# Top row: error contributions, running totals and relative error.
# Bottom row: weights, running weight products and a final comparison.
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
(ax_contrib, ax_cumul, ax_rel), (ax_weight, ax_prod, ax_final) = axes

layer_nums = [b['layer'] for b in boxes]
layer_axis = range(1, n_layers + 1)
width = 0.25

for ch in range(3):
    # Error panels
    contributions = [b['box_half_widths_input_space'][ch] for b in boxes]
    cumulative = np.cumsum(contributions)
    percentage = 100 * cumulative / x_input[ch]

    ax_contrib.bar(np.array(layer_nums) + (ch - 1) * width, contributions, width=width,
                   label=f'Ch{ch} (x={x_input[ch]:.0f})')
    ax_cumul.plot(layer_nums, cumulative, 'o-', linewidth=2, markersize=8,
                  label=f'Ch{ch} (x={x_input[ch]:.0f})')
    ax_rel.plot(layer_nums, percentage, 'o-', linewidth=2, markersize=8,
                label=f'Ch{ch}')

    # Weight panels
    weights = quant_weights[:, ch]
    cum_prod = np.cumprod(weights)
    ax_weight.plot(layer_axis, weights, 'o-', linewidth=2, markersize=8,
                   label=f'Ch{ch}')
    ax_prod.plot(layer_axis, cum_prod, 'o-', linewidth=2, markersize=8,
                 label=f'Ch{ch}')

# Reference line at 1.0: above it a weight amplifies, below it shrinks.
for weight_panel in (ax_weight, ax_prod):
    weight_panel.axhline(1.0, color='gray', linestyle='--', alpha=0.5)

# Shared decoration of the five line/bar panels: (axes, y-label, title).
panel_text = [
    (ax_contrib, 'Error half-width (input space)',
     'Error contribution per layer\n(Non-uniform weights)'),
    (ax_cumul, 'Cumulative error half-width',
     'Cumulative error\n(Minkowski sum)'),
    (ax_rel, 'Error as % of input value',
     'Relative error per channel\n(NOW DIFFERENT for each channel!)'),
    (ax_weight, 'Weight value',
     'Weight per channel per layer\n(>1 amplifies, <1 shrinks)'),
    (ax_prod, 'Cumulative weight product',
     'Total amplification per channel\n(Product of all weights)'),
]
for panel, y_text, title_text in panel_text:
    panel.set_xlabel('Layer')
    panel.set_ylabel(y_text)
    panel.set_title(title_text)
    panel.legend()
    panel.grid(True, alpha=0.3)

# Final panel: relative error (left axis) beside weight product (right axis).
ax = ax_final
final_rel_errors = total_box_half_widths / x_input * 100
final_cum_weights = np.prod(quant_weights, axis=0)

x_pos = np.arange(3)
width = 0.35
bars1 = ax.bar(x_pos - width/2, final_rel_errors, width,
               label='Relative error (%)', color='red', alpha=0.7)
ax2 = ax.twinx()  # second y-axis sharing the same x-axis
bars2 = ax2.bar(x_pos + width/2, final_cum_weights, width,
                label='Cumulative weight', color='blue', alpha=0.7)

ax.set_xlabel('Channel')
ax.set_ylabel('Relative error (%)', color='red')
ax2.set_ylabel('Cumulative weight product', color='blue')
ax.set_xticks(x_pos)
ax.set_xticklabels(['Ch0 (x=10)', 'Ch1 (x=20)', 'Ch2 (x=30)'])
ax.set_title('Final relative error vs weight product\n(Higher weight product → lower relative error)')
ax.legend(loc='upper left')
ax2.legend(loc='upper right')

plt.tight_layout()
plt.savefig('plots/experiment2_detailed_analysis.png', dpi=150, bbox_inches='tight')
plt.show()


# Figure 3: Visualize how the box shape differs from input ratios
# Left: relative error of the uniform and non-uniform networks side by side.
# Right: each channel's relative error divided by the mean over channels.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
channel_tick_labels = ['Ch0 (x=10)', 'Ch1 (x=20)', 'Ch2 (x=30)']

# Dividing by the input removes the plain "bigger input, bigger error" effect.
uniform_relative = total_uniform / x_input
nonuniform_relative = total_box_half_widths / x_input

x_pos = np.arange(3)
width = 0.35

ax = axes[0]
ax.bar(x_pos - width/2, uniform_relative * 100, width,
       label='Uniform weights', color='blue', alpha=0.7)
ax.bar(x_pos + width/2, nonuniform_relative * 100, width,
       label='Non-uniform weights', color='red', alpha=0.7)
ax.set_xlabel('Channel')
ax.set_ylabel('Relative error (%)')
ax.set_xticks(x_pos)
ax.set_xticklabels(channel_tick_labels)
ax.set_title('Relative error comparison\nUniform: all same | Non-uniform: different per channel')
ax.legend()
ax.grid(True, alpha=0.3)

# Sensitivity = relative_error / mean_relative_error
mean_rel = nonuniform_relative.mean()
sensitivity = nonuniform_relative / mean_rel

# Green bars are below the channel average, red bars at or above it.
bar_colors = []
for s in sensitivity:
    bar_colors.append('green' if s < 1 else 'red')

ax = axes[1]
ax.bar(x_pos, sensitivity, color=bar_colors, alpha=0.7)
ax.axhline(1.0, color='gray', linestyle='--', linewidth=2, label='Average sensitivity')
ax.set_xlabel('Channel')
ax.set_ylabel('Error sensitivity (relative to mean)')
ax.set_xticks(x_pos)
ax.set_xticklabels(channel_tick_labels)
ax.set_title('Channel error sensitivity\nGreen = below average, Red = above average')
ax.legend()
ax.grid(True, alpha=0.3)

# Write the numeric factor slightly above each bar.
for i, s in enumerate(sensitivity):
    ax.annotate(f'{s:.2f}x', (i, s + 0.05), ha='center', fontsize=12)

plt.tight_layout()
plt.savefig('plots/experiment2_sensitivity.png', dpi=150, bbox_inches='tight')
plt.show()


# ============================================================
# Key observations
# ============================================================

# Print the written summary of Experiment 2. The f-string below interpolates
# the per-channel relative error (total half-width / input, in percent) and
# the channels with the largest / smallest final weight product.
print("\n" + "=" * 70)
print("KEY OBSERVATIONS - EXPERIMENT 2")
print("=" * 70)

# NOTE(review): in compute_error_boxes_through_network each layer contributes
# (delta/2) * |x| / |w| in input space, so a channel's relative error is
# (delta/2) * sum over layers of 1/|w|. Observation 2 ("higher cumulative
# weight = smaller relative error") therefore describes a trend in this data,
# not an exact relation - confirm the wording is what is intended.
print(f"""
1. RELATIVE ERROR NOW VARIES BY CHANNEL
   - Channel 0: {100*total_box_half_widths[0]/x_input[0]:.4f}%
   - Channel 1: {100*total_box_half_widths[1]/x_input[1]:.4f}%
   - Channel 2: {100*total_box_half_widths[2]/x_input[2]:.4f}%
   
   Unlike Experiment 1, these are NOT the same!

2. CUMULATIVE WEIGHT PRODUCT MATTERS
   - Channel with highest weight product: {np.argmax(final_cum_weights)} (product={final_cum_weights.max():.4f})
   - Channel with lowest weight product: {np.argmin(final_cum_weights)} (product={final_cum_weights.min():.4f})
   
   Higher cumulative weight = smaller relative error (error gets "diluted" by amplification)

3. THE ERROR BOX IS NO LONGER PROPORTIONAL TO INPUT
   - In Exp 1: error box ~ [1, 2, 3] (proportional to input [10, 20, 30])
   - In Exp 2: error box is distorted by the non-uniform weights
   - Some channels accumulate more error than their input magnitude would suggest

4. ERROR SENSITIVITY IDENTIFIES VULNERABLE CHANNELS
   - Sensitivity > 1: channel accumulates MORE error than average
   - Sensitivity < 1: channel accumulates LESS error than average
   - This is determined by the weight structure, not just input magnitude

5. IMPLICATIONS FOR QUANTIZATION
   - Non-uniform weights create non-uniform error sensitivity
   - Could potentially allocate more bits to sensitive channels
   - Or design correction layers that focus on high-sensitivity channels
""")

# Verify the relationship between cumulative weight and relative error
#
# Each layer's input-space half-width is
#   (delta/2) * |value entering layer| / |weight product through the layer|
# and the value entering the layer is the input times the weight product of
# the EARLIER layers, so everything but the layer's own weight cancels:
#   half-width of layer i = (delta/2) * |x| / |w_i|.
# The relative error of a channel is therefore (delta/2) * sum_i 1/|w_i|
# (for the positive inputs used here). It is NOT rel_error * cum_weight that
# stays constant, which is what this check used to claim.
print("\nVerification - inverse relationship between weight product and relative error:")
print("-" * 60)
predicted_rel_err = (delta / 2) * np.sum(1.0 / np.abs(quant_weights), axis=0)
for ch in range(3):
    rel_err = total_box_half_widths[ch] / x_input[ch]
    cum_weight = final_cum_weights[ch]
    print(f"  Channel {ch}: rel_error={rel_err:.6f}, cum_weight={cum_weight:.4f}, "
          f"product={rel_err * cum_weight:.6f}, predicted={predicted_rel_err[ch]:.6f}")

print("\n  (rel_error should equal predicted = (delta/2) * sum over layers of 1/|w|)")
print("  (rel_error × cum_weight is NOT expected to be constant across channels:")
print("   larger weights raise the product and lower sum(1/|w|), so the two only trend inversely)")

# In [None]:
# Experiment 3: Non-Diagonal Weights
# ==================================
# 
# Setup:
# - Same 3×1 input "image": [10, 20, 30]
# - Each layer is a FULL matrix (off-diagonal entries)
# - Channels now MIX - error in one channel affects others
# - The error box becomes a PARALLELEPIPED, not an axis-aligned box

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection
from scipy.spatial import ConvexHull
from itertools import product

# ============================================================
# Setup
# ============================================================

# Seed NumPy's global RNG.
np.random.seed(42)

# The single 3-channel input that is traced through the network.
x_input = np.array([10.0, 20.0, 30.0])

# Weight quantization grid: step = 1 / 2^(bits - 1).
bits = 8
delta = 1.0 / (2 ** (bits - 1))

# Network: four dense 3x3 layers, applied in list order.
true_weights = [
    np.array(rows)
    for rows in (
        # Layer 1: mostly diagonal with small off-diagonal entries (mild mixing)
        [[0.9, 0.1, 0.0],
         [0.1, 1.1, 0.1],
         [0.0, 0.1, 0.8]],
        # Layer 2: rotation-like (strong mixing)
        [[0.8, -0.3, 0.1],
         [0.3, 0.8, -0.2],
         [-0.1, 0.2, 0.9]],
        # Layer 3: scaling with shear
        [[1.1, 0.2, 0.0],
         [0.0, 0.9, 0.2],
         [0.1, 0.0, 1.0]],
        # Layer 4: another rotation-like matrix
        [[0.9, 0.2, -0.1],
         [-0.2, 1.0, 0.1],
         [0.1, -0.1, 0.85]],
    )
]

n_layers = len(true_weights)

def quantize(W, delta):
    """Snap every entry of W to the closest multiple of delta."""
    grid_index = np.round(W / delta)
    return grid_index * delta

# Snap every layer onto the quantization grid and record the per-entry deviation.
quant_weights = []
weight_errors = []
for W_true in true_weights:
    W_quant = quantize(W_true, delta)
    quant_weights.append(W_quant)
    weight_errors.append(W_quant - W_true)

# Per-layer report: the matrices themselves plus a few scalar summaries.
print("Layer configurations:")
print("=" * 70)
for layer_no, (Wt, Wq) in enumerate(zip(true_weights, quant_weights), start=1):
    frobenius_error = np.linalg.norm(Wq - Wt)
    spectral_norm = np.linalg.norm(Wq, ord=2)
    determinant = np.linalg.det(Wq)
    print(f"\nLayer {layer_no}:")
    print(f"  True weights:\n{Wt}")
    print(f"  Quantized weights:\n{Wq}")
    print(f"  Frobenius norm of error: {frobenius_error:.6f}")
    print(f"  Spectral norm (max amplification): {spectral_norm:.4f}")
    print(f"  Determinant (volume scaling): {determinant:.4f}")

# Running product W_k @ ... @ W_1 of the quantized layers.
cumulative_product = np.eye(3)
print("\n\nCumulative transformation properties:")
print("-" * 60)
for layer_no, W in enumerate(quant_weights, start=1):
    cumulative_product = W @ cumulative_product
    U, S, Vt = np.linalg.svd(cumulative_product)
    singular_values = ", ".join(f"{s:.4f}" for s in S)
    condition_number = S.max() / S.min()
    print(f"After layer {layer_no}:")
    print(f"  Singular values: [{singular_values}]")
    print(f"  Condition number: {condition_number:.4f}")
    print(f"  Determinant: {np.linalg.det(cumulative_product):.4f}")


# ============================================================
# Core computation: error regions as parallelepipeds
# ============================================================

def get_hypercube_vertices(half_width, dims=3):
    """
    Return the 2**dims corners of an origin-centred hypercube, one per row.

    half_width multiplies the ±1 sign patterns, so it may be a scalar or
    anything that broadcasts against a (2**dims, dims) array.
    """
    sign_patterns = np.array([corner for corner in product((-1, 1), repeat=dims)])
    return sign_patterns * half_width


def transform_vertices(vertices, W):
    """Map each row-vector vertex v to W @ v and return the results stacked by row."""
    return np.matmul(vertices, W.T)


def compute_error_region_vertices(x_at_layer, W, delta):
    """
    Worst-case output-error box caused by quantizing one layer's weights.

    Model: every entry of the weight-error matrix E lies independently in
    [-delta/2, delta/2], and the resulting output error is E @ x.  Row i of E
    only affects output i, so

        |error[i]| <= sum_j |E[i, j]| * |x[j]| <= (delta / 2) * ||x||_1

    and each output coordinate can reach its bound independently of the
    others.  The error region is therefore an axis-aligned cube with the same
    half-width in every output dimension.  Any tilt seen later comes from
    pushing this cube through other layers, not from this step.

    Args:
        x_at_layer: 1-D input vector fed to the layer.
        W: The layer's weight matrix.  Only its row count (the number of
            outputs) is used; the bound does not depend on the weight values.
        delta: Quantization step of the weight grid.

    Returns:
        (vertices, hw): ``vertices`` is a (2**n_out, n_out) array of cube
        corners centred on the origin; ``hw`` is the length-n_out vector of
        half-widths (all entries equal).
    """
    n_out = np.shape(W)[0]

    # Same bound for every output: delta/2 times the L1 norm of the input.
    l1_norm = np.sum(np.abs(x_at_layer))
    hw = delta / 2 * l1_norm * np.ones(n_out)

    # All ±1 sign patterns, scaled per axis by the half-widths.
    corner_signs = np.array(list(product([-1, 1], repeat=n_out)))
    vertices = corner_signs * hw

    return vertices, hw


def trace_error_geometry(x, quant_weights, delta):
    """
    Accumulate the weight-quantization error region layer by layer, expressed
    in input-space coordinates.

    For each layer k (weights W_k, layer input v_k):
      1. build the local error box for quantizing W_k at v_k (this lives in
         the output space of layer k);
      2. pull it back to input space with the inverse of the cumulative
         transform W_k @ ... @ W_1;
      3. Minkowski-add it to the regions pulled back from earlier layers.

    Args:
        x: 1-D input vector.
        quant_weights: Sequence of square weight matrices, applied in order.
        delta: Quantization step, forwarded to compute_error_region_vertices.

    Returns:
        A list with one dict per layer.  Keys: 'layer' (1-based index),
        'value' (input to that layer), 'W', 'error_vertices_local',
        'error_vertices_input_space', 'cumulative_error_vertices',
        'cumulative_transform' (product of the weights up to and including
        this layer) and 'hw_local'.
    """
    import warnings

    history = []

    val = x.copy()
    # Forward map from input space to the current layer's input space.
    cumulative_transform = np.eye(len(x))
    total_vertices = None

    for i, W in enumerate(quant_weights):
        # Error introduced at this layer, in this layer's output space.
        error_vertices_local, hw_local = compute_error_region_vertices(val, W, delta)

        # Forward map from input space to this layer's output space.
        cumulative_transform_after = W @ cumulative_transform

        # Pull the local error back to input space.
        try:
            inv_transform = np.linalg.inv(cumulative_transform_after)
        except np.linalg.LinAlgError:
            # Best-effort fallback kept from the original: use the local box
            # as-is.  It is then NOT in input-space coordinates, so say so.
            warnings.warn(
                f"Cumulative transform after layer {i + 1} is singular; "
                "using the un-mapped local error box for this layer."
            )
            error_vertices_input_space = error_vertices_local
        else:
            error_vertices_input_space = transform_vertices(error_vertices_local, inv_transform)

        # Minkowski sum of all pulled-back error regions so far.
        if total_vertices is None:
            total_vertices = error_vertices_input_space
        else:
            total_vertices = minkowski_sum_vertices(total_vertices, error_vertices_input_space)

        history.append({
            'layer': i + 1,
            'value': val.copy(),
            'W': W.copy(),
            'error_vertices_local': error_vertices_local.copy(),
            'error_vertices_input_space': error_vertices_input_space.copy(),
            'cumulative_error_vertices': total_vertices.copy(),
            'cumulative_transform': cumulative_transform_after.copy(),
            'hw_local': hw_local.copy()
        })

        # Advance to the next layer.
        val = W @ val
        cumulative_transform = cumulative_transform_after

    return history


def minkowski_sum_vertices(V1, V2):
    """
    Vertices of the Minkowski sum of two convex point sets.

    Every pairwise sum v1 + v2 is formed (ordered v1-major, v2-minor), then
    the set is reduced to its convex-hull vertices.

    Args:
        V1: (n, d) array-like of points.
        V2: (m, d) array-like of points.

    Returns:
        The hull vertices as a (k, d) array.  If there are too few points for
        a hull, or Qhull rejects the input (e.g. all points coplanar or
        coincident), the full (n*m, d) array of pairwise sums is returned
        instead.  An empty input yields an empty 1-D array.
    """
    from scipy.spatial import QhullError

    V1 = np.asarray(V1)
    V2 = np.asarray(V2)
    if V1.size == 0 or V2.size == 0:
        return np.array([])

    dim = V1.shape[1]
    # Broadcasting gives sums[i * m + j] = V1[i] + V2[j], same order as a nested loop.
    sums = (V1[:, None, :] + V2[None, :, :]).reshape(-1, dim)

    # A full-dimensional hull needs more than dim + 1 points to be worth reducing.
    if len(sums) > dim + 1:
        try:
            hull = ConvexHull(sums)
        except (QhullError, ValueError):
            # Degenerate geometry: keep every candidate point.
            return sums
        return sums[hull.vertices]
    return sums


history = trace_error_geometry(x_input, quant_weights, delta)

# Text summary per layer: layer input, local box size, vertex counts, and the
# axis-aligned bounding box of the accumulated region.
print("\n\nError geometry through layers:")
print("=" * 70)
for h in history:
    layer_input = h['value']
    cum_verts = h['cumulative_error_vertices']
    bbox_min = cum_verts.min(axis=0)
    bbox_max = cum_verts.max(axis=0)

    print(f"\nLayer {h['layer']}:")
    print(f"  Value at layer input: [{layer_input[0]:.3f}, {layer_input[1]:.3f}, {layer_input[2]:.3f}]")
    print(f"  Local error box half-widths: {h['hw_local'][0]:.6f} (same for all dims)")
    print(f"  Error vertices in input space: {len(h['error_vertices_input_space'])} vertices")
    print(f"  Cumulative error vertices: {len(cum_verts)} vertices")
    print(f"  Cumulative error bounding box:")
    for axis in range(3):
        print(f"    Ch{axis}: [{bbox_min[axis]:.6f}, {bbox_max[axis]:.6f}]")


# ============================================================
# Visualization
# ============================================================

def draw_vertices_and_hull(ax, vertices, color, alpha=0.3, label=None):
    """
    Scatter a 3-D point cloud and shade the faces of its convex hull.

    Args:
        ax: 3-D Matplotlib axes to draw on.
        vertices: (n, 3) array of points.
        color: Colour used for both the points and the hull faces.
        alpha: Opacity of the hull faces (the points use a fixed 0.8).
        label: Optional legend label, attached to the scatter when given.

    The hull is skipped when there are fewer than 4 points or when Qhull
    rejects the cloud as degenerate; the points are still drawn in that case.
    """
    from scipy.spatial import QhullError

    scatter_kwargs = {} if label is None else {'label': label}
    ax.scatter(vertices[:, 0], vertices[:, 1], vertices[:, 2],
               c=color, s=20, alpha=0.8, **scatter_kwargs)

    if len(vertices) < 4:
        return

    try:
        hull = ConvexHull(vertices)
    except (QhullError, ValueError):
        # Flat or otherwise degenerate cloud: nothing to shade.
        return

    # One translucent triangle per hull facet.
    for simplex in hull.simplices:
        triangle = vertices[simplex]
        ax.add_collection3d(Poly3DCollection(
            [triangle], alpha=alpha, facecolor=color, edgecolor='black', linewidth=0.5
        ))


def draw_wireframe_box(ax, half_widths, color='gray', alpha=0.5):
    """Outline the origin-centred box |coord_k| <= half_widths[k] with its 12 edges."""
    hx, hy, hz = half_widths[0], half_widths[1], half_widths[2]

    segments = []
    for i, j in product((-1, 1), repeat=2):
        # Each (i, j) sign pair contributes one edge along each axis: 4 * 3 = 12.
        segments.append([[-hx*i, -hy*j, -hz], [-hx*i, -hy*j, hz]])   # parallel to z
        segments.append([[-hx*i, -hy, -hz*j], [-hx*i, hy, -hz*j]])   # parallel to y
        segments.append([[-hx, -hy*i, -hz*j], [hx, -hy*i, -hz*j]])   # parallel to x

    for start, end in segments:
        xs, ys, zs = zip(start, end)
        ax.plot3D(xs, ys, zs, color=color, alpha=alpha, linewidth=1)


# Figure 1: Error region evolution through layers
# One 3-D panel per layer. The subplot grid is fixed at 1x4, so this layout
# only has room for the first four entries of `history`.
fig = plt.figure(figsize=(18, 5))

# One viridis colour per layer for that layer's own contribution.
colors = plt.cm.viridis(np.linspace(0.2, 0.8, n_layers))

for i, (h, color) in enumerate(zip(history, colors)):
    ax = fig.add_subplot(1, 4, i+1, projection='3d')
    
    # Draw this layer's error contribution (in input space)
    draw_vertices_and_hull(ax, h['error_vertices_input_space'], color, alpha=0.4)
    
    # Draw cumulative error region
    draw_vertices_and_hull(ax, h['cumulative_error_vertices'], 'red', alpha=0.2)
    
    ax.set_xlabel('Ch0')
    ax.set_ylabel('Ch1')
    ax.set_zlabel('Ch2')
    ax.set_title(f"Layer {h['layer']}\n{len(h['cumulative_error_vertices'])} vertices")
    
    # Use the same symmetric limit on all three axes of THIS panel.
    # The limit is recomputed per layer, so panels are not on a shared scale.
    max_extent = np.abs(h['cumulative_error_vertices']).max() * 1.2
    ax.set_xlim(-max_extent, max_extent)
    ax.set_ylim(-max_extent, max_extent)
    ax.set_zlim(-max_extent, max_extent)

plt.tight_layout()
# NOTE(review): writes to a relative 'plots/' directory; presumably it must
# already exist -- confirm it is created earlier in the notebook.
plt.savefig('plots/experiment3_error_evolution.png', dpi=150, bbox_inches='tight')
plt.show()


# Figure 2: Compare final error region to axis-aligned box
# Three panels: the accumulated region, its axis-aligned bounding box, and
# both overlaid. All panels share the limits computed from `final_vertices`.
fig = plt.figure(figsize=(16, 6))

# Region after the last layer, and half of its per-axis extent.
final_vertices = history[-1]['cumulative_error_vertices']
bbox_hw = (final_vertices.max(axis=0) - final_vertices.min(axis=0)) / 2

# Plot 1: Final error region (the actual parallelepiped-ish shape)
ax1 = fig.add_subplot(131, projection='3d')
draw_vertices_and_hull(ax1, final_vertices, 'red', alpha=0.4)
ax1.set_xlabel('Ch0')
ax1.set_ylabel('Ch1')
ax1.set_zlabel('Ch2')
ax1.set_title('Actual error region\n(Non-axis-aligned)')
max_extent = np.abs(final_vertices).max() * 1.2
ax1.set_xlim(-max_extent, max_extent)
ax1.set_ylim(-max_extent, max_extent)
ax1.set_zlim(-max_extent, max_extent)

# Plot 2: Bounding box (axis-aligned approximation)
# The box is built around the origin from the half-extents, not around the
# bounding box's own centre.
# NOTE(review): this matches the true bounding box only if the region is
# symmetric about the origin -- presumably true for a sum of centred boxes; confirm.
ax2 = fig.add_subplot(132, projection='3d')
box_vertices = get_hypercube_vertices(1.0) * bbox_hw
draw_vertices_and_hull(ax2, box_vertices, 'blue', alpha=0.4)
ax2.set_xlabel('Ch0')
ax2.set_ylabel('Ch1')
ax2.set_zlabel('Ch2')
ax2.set_title('Bounding box\n(Axis-aligned approximation)')
ax2.set_xlim(-max_extent, max_extent)
ax2.set_ylim(-max_extent, max_extent)
ax2.set_zlim(-max_extent, max_extent)

# Plot 3: Overlay
ax3 = fig.add_subplot(133, projection='3d')
draw_vertices_and_hull(ax3, final_vertices, 'red', alpha=0.3)
draw_wireframe_box(ax3, bbox_hw, 'blue', alpha=0.8)
ax3.set_xlabel('Ch0')
ax3.set_ylabel('Ch1')
ax3.set_zlabel('Ch2')
ax3.set_title('Overlay\nRed=Actual, Blue wireframe=Bounding box')
ax3.set_xlim(-max_extent, max_extent)
ax3.set_ylim(-max_extent, max_extent)
ax3.set_zlim(-max_extent, max_extent)

plt.tight_layout()
# NOTE(review): relative 'plots/' directory presumably has to exist already -- confirm.
plt.savefig('plots/experiment3_final_comparison.png', dpi=150, bbox_inches='tight')
plt.show()


# Figure 3: Analyze the shape via PCA / SVD
# 2x2 grid; this section fills the top row, later sections fill the bottom row.
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# Compute SVD of the final error region
# The SVD is taken over the mean-centred vertex coordinates (one row per
# vertex), so S and Vt describe the spread of the vertices, not of the solid region.
centered = final_vertices - final_vertices.mean(axis=0)
U, S, Vt = np.linalg.svd(centered, full_matrices=False)

print("\n\nSVD analysis of final error region:")
print("-" * 50)
print(f"Singular values: [{S[0]:.6f}, {S[1]:.6f}, {S[2]:.6f}]")
print(f"Condition number: {S[0]/S[2]:.4f}")
print(f"Principal directions:")
for i, v in enumerate(Vt):
    print(f"  PC{i+1}: [{v[0]:.4f}, {v[1]:.4f}, {v[2]:.4f}]")

# Plot 1: Singular values (shape elongation)
ax = axes[0, 0]
ax.bar(range(3), S, color=['red', 'green', 'blue'], alpha=0.7)
ax.set_xticks(range(3))
ax.set_xticklabels(['PC1', 'PC2', 'PC3'])
ax.set_ylabel('Singular value')
ax.set_title('Error region principal components\n(Higher = more spread in that direction)')
ax.grid(True, alpha=0.3)

# Plot 2: Compare to diagonal case
# Recompute with diagonal weights for comparison
# The baseline keeps only the diagonal of each quantized matrix and re-runs
# the same trace, so both regions come from the same error model.
diagonal_weights = [np.diag(np.diag(W)) for W in quant_weights]
history_diag = trace_error_geometry(x_input, diagonal_weights, delta)
final_vertices_diag = history_diag[-1]['cumulative_error_vertices']
centered_diag = final_vertices_diag - final_vertices_diag.mean(axis=0)
U_diag, S_diag, Vt_diag = np.linalg.svd(centered_diag, full_matrices=False)

# Grouped bars: full-matrix vs diagonal-only singular values.
ax = axes[0, 1]
x_pos = np.arange(3)
width = 0.35
ax.bar(x_pos - width/2, S, width, label='Full matrix', color='red', alpha=0.7)
ax.bar(x_pos + width/2, S_diag, width, label='Diagonal only', color='blue', alpha=0.7)
ax.set_xticks(x_pos)
ax.set_xticklabels(['PC1', 'PC2', 'PC3'])
ax.set_ylabel('Singular value')
ax.set_title('Shape comparison: Full vs Diagonal weights')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 3: Volume comparison through layers
from scipy.spatial import QhullError  # raised by ConvexHull on degenerate point sets


def _hull_volume_or_zero(vertices):
    """Convex-hull volume of a vertex cloud; 0 when Qhull rejects it as degenerate."""
    try:
        return ConvexHull(vertices).volume
    except (QhullError, ValueError):
        return 0


ax = axes[1, 0]
volumes_full = []
volumes_diag = []

# Hull volume of the accumulated region after each layer, for both networks.
for h_full, h_diag in zip(history, history_diag):
    volumes_full.append(_hull_volume_or_zero(h_full['cumulative_error_vertices']))
    volumes_diag.append(_hull_volume_or_zero(h_diag['cumulative_error_vertices']))

layers = [h['layer'] for h in history]
ax.plot(layers, volumes_full, 'o-', linewidth=2, markersize=8, label='Full matrix', color='red')
ax.plot(layers, volumes_diag, 's-', linewidth=2, markersize=8, label='Diagonal only', color='blue')
ax.set_xlabel('Layer')
ax.set_ylabel('Error region volume')
ax.set_title('Error volume growth\n(Full matrices can have different volume than diagonal)')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 4: Bounding box inefficiency
# How much of the bounding box is "wasted" (not part of actual error region)?
from scipy.spatial import QhullError  # raised by ConvexHull on degenerate point sets

ax = axes[1, 1]
bbox_volumes = []
actual_volumes = []

for h in history:
    verts = h['cumulative_error_vertices']
    # Volume of the axis-aligned bounding box of the vertex cloud.
    bbox = np.prod(verts.max(axis=0) - verts.min(axis=0))
    bbox_volumes.append(bbox)
    try:
        hull = ConvexHull(verts)
    except (QhullError, ValueError):
        # Degenerate region: fall back to the box volume (efficiency 1).
        actual_volumes.append(bbox)
    else:
        actual_volumes.append(hull.volume)

# Ratio of hull volume to bounding-box volume; a zero-volume box counts as 1.
efficiency = [a/b if b > 0 else 1 for a, b in zip(actual_volumes, bbox_volumes)]

ax.bar(layers, efficiency, color='purple', alpha=0.7)
ax.axhline(1.0, color='gray', linestyle='--', label='Perfect efficiency (cube)')
ax.set_xlabel('Layer')
ax.set_ylabel('Volume efficiency (actual / bounding box)')
ax.set_title('Bounding box efficiency\n(<1 means error region is tilted/elongated)')
ax.set_ylim(0, 1.2)
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('plots/experiment3_shape_analysis.png', dpi=150, bbox_inches='tight')
plt.show()


# Figure 4: 2D projections to see the tilt
from scipy.spatial import QhullError  # raised by ConvexHull on degenerate point sets

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# (column index, column index, x label, y label) for each coordinate plane.
projections = [(0, 1, 'Ch0', 'Ch1'), (0, 2, 'Ch0', 'Ch2'), (1, 2, 'Ch1', 'Ch2')]

for ax, (i, j, xlabel, ylabel) in zip(axes, projections):
    # Project final vertices
    proj_full = final_vertices[:, [i, j]]
    proj_diag = final_vertices_diag[:, [i, j]]

    # Only the hull construction is guarded; drawing errors should surface.
    try:
        hull_full = ConvexHull(proj_full)
        hull_diag = ConvexHull(proj_diag)
    except (QhullError, ValueError):
        # Degenerate projection: show the raw points, labelled so the legend still works.
        ax.scatter(proj_full[:, 0], proj_full[:, 1], c='red', alpha=0.5, label='Full matrix')
        ax.scatter(proj_diag[:, 0], proj_diag[:, 1], c='blue', alpha=0.5, label='Diagonal only')
    else:
        # Full matrix
        for simplex in hull_full.simplices:
            ax.plot(proj_full[simplex, 0], proj_full[simplex, 1], 'r-', linewidth=2)
        ax.fill(proj_full[hull_full.vertices, 0], proj_full[hull_full.vertices, 1],
                'red', alpha=0.3, label='Full matrix')

        # Diagonal
        for simplex in hull_diag.simplices:
            ax.plot(proj_diag[simplex, 0], proj_diag[simplex, 1], 'b--', linewidth=2)
        ax.fill(proj_diag[hull_diag.vertices, 0], proj_diag[hull_diag.vertices, 1],
                'blue', alpha=0.2, label='Diagonal only')

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(f'Projection onto {xlabel}-{ylabel} plane')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal')
    ax.axhline(0, color='k', linewidth=0.5)
    ax.axvline(0, color='k', linewidth=0.5)

plt.tight_layout()
plt.savefig('plots/experiment3_projections.png', dpi=150, bbox_inches='tight')
plt.show()


# ============================================================
# Key observations
# ============================================================

# Banner for the Experiment 3 summary.
print("\n" + "=" * 70)
print("KEY OBSERVATIONS - EXPERIMENT 3")
print("=" * 70)

# The summary interpolates values computed in the Figure 3 section:
# Vt and S (SVD of the final vertex cloud), efficiency (hull / bounding-box
# volume per layer) and volumes_full / volumes_diag (hull volume per layer).
# NOTE(review): 1/efficiency[-1] and volumes_full[-1]/volumes_diag[-1] divide
# by values that are 0 when a hull could not be built -- confirm that cannot
# happen for these weights.
print(f"""
1. ERROR REGION IS NO LONGER AXIS-ALIGNED
   - Off-diagonal weights cause channels to mix
   - Error in one channel propagates to others
   - The error region becomes a parallelepiped (or more complex polytope)

2. THE SHAPE IS TILTED
   - Principal components are not aligned with coordinate axes
   - PC1 direction: [{Vt[0,0]:.3f}, {Vt[0,1]:.3f}, {Vt[0,2]:.3f}]
   - This means error is correlated across channels

3. BOUNDING BOX OVERESTIMATES ERROR
   - Actual volume / Bounding box volume = {efficiency[-1]:.3f}
   - Using axis-aligned bounds would overestimate error by {100*(1/efficiency[-1] - 1):.1f}%
   - The true error region is more constrained than the box suggests

4. SINGULAR VALUES REVEAL ERROR ANISOTROPY
   - Largest SV: {S[0]:.6f} (most error in this direction)
   - Smallest SV: {S[2]:.6f} (least error in this direction)
   - Ratio: {S[0]/S[2]:.2f}x (error is {S[0]/S[2]:.1f}x larger in worst vs best direction)

5. COMPARISON TO DIAGONAL CASE
   - Full matrix volume: {volumes_full[-1]:.6f}
   - Diagonal-only volume: {volumes_diag[-1]:.6f}
   - Ratio: {volumes_full[-1]/volumes_diag[-1]:.3f}x
   
6. IMPLICATIONS
   - Can't analyze channels independently when weights mix them
   - Error correction needs to account for correlations
   - The "worst case" error direction may not align with any single channel
   - PCA of error region tells you where to focus correction efforts
""")

In [None]:
# Experiment 4: Multiple Input Points - Error Manifold
# ====================================================
# 
# Setup:
# - Multiple input points (not just one)
# - See how error region varies across input space
# - Connect input geometry to error geometry
#
# Key question: How does the "error manifold" relate to the "input manifold"?

import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import ConvexHull

# ============================================================
# Setup - Keep it simple: 2D for visualization
# ============================================================

# Seed NumPy's global RNG.
np.random.seed(42)

# Quantization grid: step = 1 / 2^(bits - 1).
bits = 8
delta = 1.0 / (2 ** (bits - 1))

# Network: 3 layers, 2D -> 2D.
# Each matrix has both diagonal and off-diagonal entries to show both effects.
true_weights = [
    np.array(rows)
    for rows in (
        [[0.9, 0.2], [0.1, 1.1]],
        [[1.1, -0.1], [0.2, 0.8]],
        [[0.85, 0.15], [-0.1, 1.0]],
    )
]

def quantize(W, delta):
    """Round each entry of W to the nearest multiple of delta."""
    steps = np.round(W / delta)
    return steps * delta

# Snap every layer onto the quantization grid.
quant_weights = [quantize(W_true, delta) for W_true in true_weights]

# Show the quantized matrices, numbered from 1.
print("Network configuration (2D):")
print("-" * 40)
for layer_no, W_quant in enumerate(quant_weights, start=1):
    print(f"Layer {layer_no}:\n{W_quant}\n")


# ============================================================
# Core: compute error for a single input point
# ============================================================

def compute_error_half_width(x, quant_weights, delta):
    """
    Build the worst-case weight-quantization error region for a 2-D input x,
    expressed in input-space coordinates.

    Despite the name, the result is a polygon (its vertices), not a pair of
    half-widths.  For each layer the local error is an axis-aligned square
    with half-width delta/2 * ||layer input||_1; it is mapped back to input
    space through the inverse of the cumulative weight product and
    Minkowski-added to the region accumulated so far.

    Args:
        x: Length-2 input vector.
        quant_weights: Sequence of 2x2 weight matrices, applied in order.
        delta: Quantization step of the weight grid.

    Returns:
        (k, 2) array of vertices of the accumulated error region, centred on
        the origin.  When the region is degenerate (e.g. x is all zeros) the
        un-reduced candidate points are returned.  Returns None if
        quant_weights is empty.
    """
    from scipy.spatial import QhullError

    unit_square = np.array([[-1, -1], [-1, 1], [1, 1], [1, -1]])

    val = x.copy()
    cumulative_W = np.eye(2)
    total_error_vertices = None

    for W in quant_weights:
        # Each output is sum_j W_err[i, j] * val[j] with |W_err| <= delta/2,
        # so every output dimension has the same bound.
        local_hw = delta / 2 * np.sum(np.abs(val))
        local_vertices = unit_square * local_hw

        # Map from this layer's output space back to input space.
        cumulative_W = W @ cumulative_W
        try:
            inv_W = np.linalg.inv(cumulative_W)
        except np.linalg.LinAlgError:
            # Singular product: keep the un-mapped square (best effort).
            local_vertices_input = local_vertices
        else:
            local_vertices_input = local_vertices @ inv_W.T

        if total_error_vertices is None:
            total_error_vertices = local_vertices_input
        else:
            # Minkowski sum: all pairwise sums, in the same order as a nested loop.
            new_vertices = (
                total_error_vertices[:, None, :] + local_vertices_input[None, :, :]
            ).reshape(-1, 2)
            try:
                hull = ConvexHull(new_vertices)
            except (QhullError, ValueError):
                # Degenerate point set: keep every candidate.
                total_error_vertices = new_vertices
            else:
                total_error_vertices = new_vertices[hull.vertices]

        # Input to the next layer.
        val = W @ val

    return total_error_vertices


def compute_error_magnitude(x, quant_weights, delta):
    """Largest Euclidean distance from the origin to any vertex of x's error region."""
    region = compute_error_half_width(x, quant_weights, delta)
    vertex_lengths = np.linalg.norm(region, axis=1)
    return np.max(vertex_lengths)


# ============================================================
# Generate input points on different manifolds
# ============================================================

# Manifold 1: circle of radius 20, sampled at 50 evenly spaced angles
# (2*pi itself is excluded, so the first point is not repeated).
n_points = 50
radius = 20
theta = np.linspace(0, 2*np.pi, n_points, endpoint=False)
circle_x = radius * np.cos(theta)
circle_y = radius * np.sin(theta)
circle_points = np.column_stack([circle_x, circle_y])

# Manifold 2: straight segment from (-30, -10) to (30, 10).
line_x = np.linspace(-30, 30, n_points)
line_y = np.linspace(-10, 10, n_points)
line_points = np.column_stack([line_x, line_y])

# Manifold 3: 15x15 grid over [-30, 30]^2, flattened to one point per row.
grid_1d = np.linspace(-30, 30, 15)
grid_x, grid_y = np.meshgrid(grid_1d, grid_1d)
grid_points = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)


# ============================================================
# Compute errors for all points
# ============================================================

print("Computing error regions for input manifolds...")

# Circle: keep both the scalar magnitude and the region vertices per point.
circle_errors = []
circle_vertices = []
for pt in circle_points:
    circle_errors.append(compute_error_magnitude(pt, quant_weights, delta))
    circle_vertices.append(compute_error_half_width(pt, quant_weights, delta))

# Line: scalar magnitude only.
line_errors = []
for pt in line_points:
    line_errors.append(compute_error_magnitude(pt, quant_weights, delta))

# Grid: scalar magnitude per point, plus a 2-D view shaped like the meshgrid.
grid_errors = np.array([compute_error_magnitude(pt, quant_weights, delta) for pt in grid_points])
grid_side = len(grid_1d)
grid_errors_2d = grid_errors.reshape(grid_side, grid_side)


# ============================================================
# Visualization
# ============================================================

# Figure 1: Error magnitude varies with input position
# Three panels, one per input manifold (circle, line, grid).
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Plot 1: Circle - colored by error magnitude
ax = axes[0]
scatter = ax.scatter(circle_points[:, 0], circle_points[:, 1], 
                     c=circle_errors, cmap='hot', s=100, edgecolors='black')
plt.colorbar(scatter, ax=ax, label='Error magnitude')
ax.set_xlabel('Input dim 0')
ax.set_ylabel('Input dim 1')
ax.set_title('Circle manifold\nColor = error magnitude')
ax.set_aspect('equal')
ax.grid(True, alpha=0.3)

# Plot 2: Line - error vs position
ax = axes[1]
# Signed distance from the origin: the norm, negated where dim 0 is negative.
positions = np.linalg.norm(line_points, axis=1) * np.sign(line_points[:, 0])
ax.plot(positions, line_errors, 'o-', linewidth=2, markersize=6)
ax.set_xlabel('Position along line')
ax.set_ylabel('Error magnitude')
ax.set_title('Line manifold\nError grows with distance from origin')
ax.grid(True, alpha=0.3)

# Plot 3: Grid - heatmap
ax = axes[2]
# origin='lower' puts row 0 of grid_errors_2d at the bottom; the extent
# matches the [-30, 30] range used to build the grid.
im = ax.imshow(grid_errors_2d, extent=[-30, 30, -30, 30], origin='lower', cmap='hot')
plt.colorbar(im, ax=ax, label='Error magnitude')
ax.set_xlabel('Input dim 0')
ax.set_ylabel('Input dim 1')
ax.set_title('Error magnitude across input space\nBrighter = more error')
ax.set_aspect('equal')

plt.tight_layout()
# NOTE(review): relative 'plots/' directory presumably has to exist already -- confirm.
plt.savefig('plots/experiment4_error_magnitude.png', dpi=150, bbox_inches='tight')
plt.show()


# Figure 2: Error regions for selected points on the circle
from scipy.spatial import QhullError  # raised by ConvexHull on degenerate point sets

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Select 8 points around the circle
selected_indices = np.linspace(0, n_points-1, 8, dtype=int)

# One shared axis limit for every panel, computed once.
max_err = max(circle_errors) * 1.2

for ax, i in zip(axes.flat, selected_indices):
    point = circle_points[i]
    vertices = circle_vertices[i]

    # Draw error region centered at origin
    if len(vertices) >= 3:
        try:
            hull = ConvexHull(vertices)
        except (QhullError, ValueError):
            # Degenerate region: leave the panel with just the origin marker.
            hull = None
        if hull is not None:
            hull_vertices = vertices[hull.vertices]
            hull_vertices = np.vstack([hull_vertices, hull_vertices[0]])  # Close polygon
            ax.fill(hull_vertices[:, 0], hull_vertices[:, 1], 'red', alpha=0.3)
            ax.plot(hull_vertices[:, 0], hull_vertices[:, 1], 'r-', linewidth=2)

    ax.scatter([0], [0], c='black', s=50, zorder=5)

    ax.set_title(f'Input: ({point[0]:.0f}, {point[1]:.0f})\nError: {circle_errors[i]:.4f}')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

    # Set consistent scale
    ax.set_xlim(-max_err, max_err)
    ax.set_ylim(-max_err, max_err)
    ax.axhline(0, color='k', linewidth=0.5)
    ax.axvline(0, color='k', linewidth=0.5)

plt.suptitle('Error regions for points around the circle\n(Shape changes with input direction!)', fontsize=12)
plt.tight_layout()
plt.savefig('plots/experiment4_error_regions_circle.png', dpi=150, bbox_inches='tight')
plt.show()


# Figure 3: The key insight - error magnitude vs input magnitude
# Both panels use the grid points; only the norm on the x-axis differs.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Error vs L1 norm of input
ax = axes[0]
l1_norms = np.sum(np.abs(grid_points), axis=1)
ax.scatter(l1_norms, grid_errors, alpha=0.5, s=20)

# Fit line
# Degree-1 least-squares fit; z[0] is the slope and z[1] the intercept.
# z is read again by the summary printed at the end of this experiment.
z = np.polyfit(l1_norms, grid_errors, 1)
p = np.poly1d(z)
x_fit = np.linspace(l1_norms.min(), l1_norms.max(), 100)
ax.plot(x_fit, p(x_fit), 'r-', linewidth=2, label=f'Linear fit: y = {z[0]:.4f}x + {z[1]:.4f}')

ax.set_xlabel('L1 norm of input (|x₀| + |x₁|)')
ax.set_ylabel('Error magnitude')
ax.set_title('Error scales with input L1 norm\n(But with scatter due to direction)')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 2: Error vs L2 norm
ax = axes[1]
l2_norms = np.linalg.norm(grid_points, axis=1)
# Points are coloured by polar angle (arctan2 of dim 1 over dim 0) on a cyclic colormap.
ax.scatter(l2_norms, grid_errors, alpha=0.5, s=20, c=np.arctan2(grid_points[:,1], grid_points[:,0]), cmap='hsv')

ax.set_xlabel('L2 norm of input (distance from origin)')
ax.set_ylabel('Error magnitude')
ax.set_title('Error vs distance from origin\nColor = angle (shows directional dependence)')
ax.grid(True, alpha=0.3)

plt.tight_layout()
# NOTE(review): relative 'plots/' directory presumably has to exist already -- confirm.
plt.savefig('plots/experiment4_error_vs_norm.png', dpi=150, bbox_inches='tight')
plt.show()


# Figure 4: Error region shape varies with direction
from scipy.spatial import QhullError  # raised by ConvexHull on degenerate point sets

fig, ax = plt.subplots(1, 1, figsize=(10, 10))

# Draw all error regions overlaid, centered at their input points.
# Every region is multiplied by the same factor before being placed.
scale = 0.3

for point, vertices in zip(circle_points, circle_vertices):
    if len(vertices) < 3:
        continue
    try:
        hull = ConvexHull(vertices)
    except (QhullError, ValueError):
        # Degenerate region: nothing to draw for this point.
        continue
    hull_vertices = vertices[hull.vertices] * scale + point
    hull_vertices = np.vstack([hull_vertices, hull_vertices[0]])  # Close polygon
    ax.fill(hull_vertices[:, 0], hull_vertices[:, 1], 'red', alpha=0.2)
    ax.plot(hull_vertices[:, 0], hull_vertices[:, 1], 'r-', linewidth=0.5)

# Draw the circle
ax.plot(circle_points[:, 0], circle_points[:, 1], 'b-', linewidth=2, label='Input manifold (circle)')
ax.scatter(circle_points[:, 0], circle_points[:, 1], c='blue', s=30, zorder=5)

ax.set_xlabel('Input dim 0')
ax.set_ylabel('Input dim 1')
ax.set_title('Input manifold with error regions\n(Red shapes show error region at each point, scaled)')
ax.set_aspect('equal')
ax.grid(True, alpha=0.3)
ax.legend()

plt.tight_layout()
plt.savefig('plots/experiment4_manifold_with_errors.png', dpi=150, bbox_inches='tight')
plt.show()


# ============================================================
# Key observations
# ============================================================

# Banner for the Experiment 4 summary.
print("\n" + "=" * 70)
print("KEY OBSERVATIONS - EXPERIMENT 4")
print("=" * 70)

# Compute some statistics
# All four are taken over the circle points only (constant L2 radius).
# error_mean and error_std are computed but not used in the text below.
error_mean = np.mean(circle_errors)
error_std = np.std(circle_errors)
error_min = np.min(circle_errors)
error_max = np.max(circle_errors)

# The summary interpolates error_min / error_max, radius, and the slope z[0]
# of the linear fit from the "error vs L1 norm" figure.
print(f"""
1. ERROR MAGNITUDE SCALES WITH INPUT MAGNITUDE
   - Points further from origin have larger error
   - This is because error = weight_error × value
   - Larger values → larger absolute error

2. ERROR MAGNITUDE DEPENDS ON DIRECTION
   - Points at same distance but different angles have different errors
   - On the circle (constant radius), error varies from {error_min:.4f} to {error_max:.4f}
   - Ratio: {error_max/error_min:.2f}x variation
   - This is due to non-diagonal weights creating directional preference

3. ERROR REGION SHAPE VARIES WITH INPUT
   - Not just magnitude — the shape (orientation) changes
   - Error region "tilts" based on input direction
   - This comes from how the cumulative weight transform depends on path through network

4. THE ERROR MANIFOLD
   - Input manifold: circle of radius {radius}
   - Error at each point creates a "tube" around the manifold
   - The tube thickness varies with position
   - The tube cross-section shape varies with position

5. IMPLICATIONS FOR QUANTIZATION
   - Some directions in input space are "safer" than others
   - Could potentially transform data to align with low-error directions
   - For classification: if decision boundary aligns with high-error direction, more mistakes

6. RELATIONSHIP: ERROR ~ k × L1_NORM (approximately)
   - Linear fit slope: {z[0]:.4f}
   - But there's scatter due to directional effects
   - L1 norm matters because error = delta × sum(|weights| × |inputs|)
""")

# Verify the L1 relationship
# Pearson correlation between the L1 norm of each circle point and its error
# magnitude; [0, 1] picks the off-diagonal entry of the 2x2 correlation matrix.
print("\nVerification - error vs L1 norm for circle points:")
circle_l1 = np.sum(np.abs(circle_points), axis=1)
correlation = np.corrcoef(circle_l1, circle_errors)[0, 1]
print(f"  Correlation between L1 norm and error: {correlation:.4f}")
print(f"  (Not perfect because direction also matters)")