# EM-Style Reservoir Training for INR

**Key Idea**: Alternate between training the readout and updating the reservoir.

Traditional RC: Random fixed reservoir → train only readout

EM-RC: Iterate until convergence:
- **E-step**: Fix reservoir weights → solve for optimal readout (ridge regression)
- **M-step**: Fix readout → update reservoir weights to reduce reconstruction error

This simulates an **evolving reservoir**: the reservoir weights are refined in every EM iteration, and training stops once the improvement stalls.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

try:
    import torch
    import torch.nn as nn
    USE_TORCH = True
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"PyTorch available, using {device}")
except ImportError:
    USE_TORCH = False
    print("PyTorch not available, using NumPy only")

np.random.seed(42)

In [None]:
# Load the target image and build the (pixel coordinate -> RGB) dataset.
target_size = 128
# Open inside a context manager so the file handle is released as soon as the
# pixels are decoded; convert() returns an independent in-memory image.
with Image.open('fig/cat.png') as src:
    img = src.convert('RGB')
img = img.resize((target_size, target_size), Image.LANCZOS)
img_array = np.array(img) / 255.0  # 8-bit channels -> floats in [0, 1]

h, w, c = img_array.shape
# Normalised coordinates in [0, 1). Each axis uses its own resolution so the
# number of coordinate rows always equals the number of pixels (h * w), even
# if the image is not square.
x_coords = np.linspace(0, 1, w, endpoint=False)
y_coords = np.linspace(0, 1, h, endpoint=False)
x_grid = np.stack(np.meshgrid(x_coords, y_coords), -1)  # (h, w, 2)
X = x_grid.reshape(-1, 2)     # one (x, y) row per pixel
Y = img_array.reshape(-1, 3)  # matching RGB row per pixel

print(f"Image: {h}x{w}, Samples: {len(X)}")
plt.imshow(img_array)
plt.title('Target')
plt.axis('off')
plt.show()

## EM-Reservoir: NumPy Implementation

```
Initialize: Random W_in, W_hh, b

Repeat until convergence:
    E-step: h = reservoir(X; W_in, W_hh, b)
            W_out = ridge_solve(h, Y)  # Optimal readout
            
    M-step: Compute gradient of loss w.r.t. W_in, W_hh, b
            Update reservoir weights via gradient descent
```

In [None]:
class EMReservoir:
    """
    NumPy reservoir with a closed-form (ridge regression) readout.

    The reservoir is a stack of ``num_layers`` recurrent tanh layers with
    randomly initialised weights. This class implements the forward pass and
    the E-step (readout fit) only; it has no method that updates the reservoir
    weights.

    Args:
        input_dim: Number of input features per sample.
        hidden_size: Number of units in every reservoir layer.
        output_dim: Number of regression targets (stored, not used internally).
        num_layers: Number of stacked reservoir layers.
        spectral_radius: Target spectral radius of each recurrent matrix.
    """
    def __init__(self, input_dim, hidden_size, output_dim,
                 num_layers=1, spectral_radius=0.9):
        self.input_dim = input_dim
        self.hidden_size = hidden_size
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.spectral_radius = spectral_radius

        # Initialize reservoir weights
        self._init_reservoir()

        # Readout weights (None until e_step has been called)
        self.W_out = None

    def _init_reservoir(self):
        """Draw random weights for every layer from the global NumPy RNG."""
        self.layers = []
        d_in = self.input_dim

        for _ in range(self.num_layers):
            W_in = np.random.randn(d_in, self.hidden_size) * 0.5
            W_hh = np.random.randn(self.hidden_size, self.hidden_size)
            # Rescale so the largest |eigenvalue| equals spectral_radius
            eig = np.abs(np.linalg.eigvals(W_hh)).max()
            W_hh = W_hh * (self.spectral_radius / eig)
            b = np.random.randn(self.hidden_size) * 0.1

            self.layers.append({'W_in': W_in, 'W_hh': W_hh, 'b': b})
            # Deeper layers consume the previous layer's hidden state
            d_in = self.hidden_size

    def forward(self, X, num_iterations=10):
        """
        Run the inputs through every reservoir layer.

        Each layer starts from a zero state and applies its tanh recurrence
        ``num_iterations`` times with the layer input held constant.

        Args:
            X: Array of shape (n, input_dim).
            num_iterations: Number of recurrence steps per layer.

        Returns:
            Array of shape (n, num_layers * hidden_size): the final states of
            all layers concatenated along the feature axis.
        """
        n = X.shape[0]
        all_h = []

        layer_input = X
        for layer in self.layers:
            W_in, W_hh, b = layer['W_in'], layer['W_hh'], layer['b']
            h = np.zeros((n, self.hidden_size))

            # Iterative settling
            for _ in range(num_iterations):
                h = np.tanh(layer_input @ W_in + h @ W_hh + b)

            all_h.append(h)
            layer_input = h

        # Concatenate all layer outputs
        return np.hstack(all_h)

    def e_step(self, X, Y, lamb=1e-6):
        """
        E-step: fix the reservoir and fit the readout by ridge regression.

        Args:
            X: Inputs of shape (n, input_dim).
            Y: Targets with n rows.
            lamb: Ridge regularisation strength.

        Returns:
            The reservoir feature matrix H used for the fit.
        """
        H = self.forward(X)
        # Ridge regression: W_out = (H^T H + λI)^(-1) H^T Y
        self.W_out = np.linalg.solve(
            H.T @ H + lamb * np.eye(H.shape[1]),
            H.T @ Y
        )
        return H

    def predict(self, X):
        """
        Predict outputs for X, clipped to [0, 1].

        Raises:
            RuntimeError: If the readout has not been fitted with e_step yet.
        """
        if self.W_out is None:
            # Fail with an actionable message instead of an opaque
            # "unsupported operand" error from `H @ None`.
            raise RuntimeError(
                "Readout is not fitted; call e_step(X, Y) before predict()."
            )
        H = self.forward(X)
        return np.clip(H @ self.W_out, 0, 1)

    def compute_loss(self, X, Y):
        """
        Return ``(mse, psnr)`` of the clipped prediction against Y.

        PSNR is ``-10 * log10(mse)``; an exactly zero MSE is reported as 100.
        """
        pred = self.predict(X)
        mse = np.mean((pred - Y) ** 2)
        psnr = -10 * np.log10(mse) if mse > 0 else 100
        return mse, psnr

# Smoke test: fit the readout once on the untrained (random) reservoir and
# report the resulting reconstruction quality.
em_res = EMReservoir(2, 256, 3, num_layers=3)
H = em_res.e_step(X, Y)
mse, psnr = em_res.compute_loss(X, Y)
print(f"Initial (random reservoir): PSNR = {psnr:.2f} dB")

## M-Step: Update Reservoir Weights

Use gradient-based optimization (Adam, with gradient-norm clipping) to update the reservoir weights while keeping the readout fixed.

We need gradients through the reservoir, so we'll use PyTorch for this.

In [None]:
if USE_TORCH:
    class EMReservoirTorch(nn.Module):
        """
        EM-style reservoir with trainable weights.

        The reservoir weights (W_in, W_hh, b per layer) are nn.Parameters so
        they can be updated by gradient steps (M-step). The readout W_out is a
        buffer that is overwritten by the closed-form solve in e_step.

        Args:
            input_dim: Number of input features per sample.
            hidden_size: Number of units in every reservoir layer.
            output_dim: Number of regression targets.
            num_layers: Number of stacked reservoir layers.
            spectral_radius: Spectral radius each W_hh is scaled to at
                initialisation (it is not re-enforced afterwards).
            num_iterations: Recurrence steps applied per layer in forward().
        """
        def __init__(self, input_dim, hidden_size, output_dim,
                     num_layers=1, spectral_radius=0.9, num_iterations=10):
            super().__init__()
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.num_iterations = num_iterations
            self.spectral_radius = spectral_radius

            # Reservoir weights (trainable in M-step)
            self.W_in = nn.ParameterList()
            self.W_hh = nn.ParameterList()
            self.b = nn.ParameterList()

            d_in = input_dim
            for _ in range(num_layers):
                self.W_in.append(nn.Parameter(torch.randn(d_in, hidden_size) * 0.5))

                # Initialize W_hh with spectral radius control
                W_hh_init = torch.randn(hidden_size, hidden_size)
                eig = torch.linalg.eigvals(W_hh_init).abs().max()
                W_hh_init = W_hh_init * (spectral_radius / eig)
                self.W_hh.append(nn.Parameter(W_hh_init))

                self.b.append(nn.Parameter(torch.randn(hidden_size) * 0.1))
                # Deeper layers consume the previous layer's hidden state
                d_in = hidden_size

            # Readout weights (computed in E-step, not trained by gradient)
            self.register_buffer('W_out', torch.zeros(num_layers * hidden_size, output_dim))

        def forward(self, X):
            """
            Run X through every reservoir layer.

            Each layer starts from a zero state and applies its tanh
            recurrence ``num_iterations`` times with its input held constant.

            Returns:
                Tensor of shape (batch, num_layers * hidden_size) holding the
                final states of all layers concatenated along dim 1.
            """
            batch_size = X.shape[0]
            all_h = []

            layer_input = X
            for W_in, W_hh, b in zip(self.W_in, self.W_hh, self.b):
                h = torch.zeros(batch_size, self.hidden_size, device=X.device)

                for _ in range(self.num_iterations):
                    h = torch.tanh(layer_input @ W_in + h @ W_hh + b)

                all_h.append(h)
                layer_input = h

            return torch.cat(all_h, dim=1)

        def e_step(self, X, Y, lamb=1e-3):
            """
            E-step: Compute optimal readout (closed-form ridge regression).
            Uses lstsq for numerical stability.

            Args:
                X: Inputs of shape (n, input_dim).
                Y: Targets with n rows.
                lamb: Non-negative ridge regularisation strength.

            Returns:
                The (gradient-free) feature matrix H used for the fit.

            Raises:
                ValueError: If ``lamb`` is negative.
            """
            if lamb < 0:
                # sqrt(lamb) below would otherwise silently inject NaNs
                raise ValueError("lamb must be non-negative")
            with torch.no_grad():
                H = self.forward(X)
                d = H.shape[1]
                # Ridge as an augmented least-squares problem: stacking
                # sqrt(lamb) * I under H (and zeros under Y) makes
                # ||H_reg W - Y_reg||^2 equal ||H W - Y||^2 + lamb * ||W||^2.
                ridge_rows = (lamb ** 0.5) * torch.eye(d, device=H.device, dtype=H.dtype)
                H_reg = torch.cat([H, ridge_rows], dim=0)
                Y_reg = torch.cat([Y, torch.zeros(d, Y.shape[1], device=Y.device)], dim=0)
                # Use lstsq (more stable than solve)
                result = torch.linalg.lstsq(H_reg, Y_reg)
                self.W_out = result.solution
            return H

        def predict(self, X):
            """Predict with current reservoir and readout, clamped to [0, 1]."""
            H = self.forward(X)
            return torch.clamp(H @ self.W_out, 0, 1)

        def m_step_loss(self, X, Y, clamp_output=False):
            """
            M-step loss: reconstruction error with FIXED readout.
            Gradients flow through reservoir weights only.

            By default the MSE is taken on the raw (unclamped) readout output.
            Clamping inside the loss gives zero gradient for every prediction
            outside [0, 1], so such pixels could never be corrected, and it
            would make the M-step optimise a different objective from the
            unclamped least-squares problem solved in e_step.

            Args:
                X: Inputs of shape (n, input_dim).
                Y: Targets with n rows.
                clamp_output: If True, clamp predictions to [0, 1] before the
                    MSE (the earlier behaviour of this method).

            Returns:
                Scalar MSE tensor.
            """
            H = self.forward(X)
            # Use detached W_out (no gradient through it)
            pred = H @ self.W_out.detach()
            if clamp_output:
                pred = torch.clamp(pred, 0, 1)
            return torch.mean((pred - Y) ** 2)

    print("EMReservoirTorch defined")
else:
    print("PyTorch not available")

## EM Training Loop

In [None]:
if USE_TORCH:
    def train_em_reservoir(X, Y, hidden_size=256, num_layers=3,
                           num_em_iterations=20, m_steps_per_em=100,
                           lr=1e-3, patience=3):
        """
        EM-style training:
        1. E-step: Fix reservoir, solve for optimal readout
        2. M-step: Fix readout, update reservoir via gradient descent
        3. Repeat until convergence
        4. Final E-step so the returned readout matches the final reservoir

        Args:
            X: Inputs of shape (n, input_dim); array-like or tensor.
            Y: Targets of shape (n, output_dim); array-like or tensor.
            hidden_size: Units per reservoir layer.
            num_layers: Number of reservoir layers.
            num_em_iterations: Maximum number of E/M rounds.
            m_steps_per_em: Optimizer steps per M-step.
            lr: Adam learning rate for the reservoir weights.
            patience: Stop after this many rounds without a PSNR gain of
                more than 0.1 dB over the best value so far.

        Returns:
            (model, history) where history has parallel lists 'em_iter',
            'psnr' and 'mse'. Entries at x.5 are measured right after an
            E-step, integer entries right after an M-step. The last entry is
            the final E-step and describes the returned model.
        """
        X_t = torch.as_tensor(X, dtype=torch.float32, device=device)
        Y_t = torch.as_tensor(Y, dtype=torch.float32, device=device)

        # Take the dimensions from the data instead of assuming 2-D -> RGB
        model = EMReservoirTorch(
            input_dim=X_t.shape[1], hidden_size=hidden_size,
            output_dim=Y_t.shape[1],
            num_layers=num_layers, num_iterations=10
        ).to(device)

        # Optimizer for reservoir weights only
        reservoir_params = list(model.W_in) + list(model.W_hh) + list(model.b)
        optimizer = torch.optim.Adam(reservoir_params, lr=lr)

        history = {'em_iter': [], 'psnr': [], 'mse': []}

        def evaluate():
            """Return (mse, psnr) of model.predict on the training data."""
            with torch.no_grad():
                pred = model.predict(X_t)
                mse = torch.mean((pred - Y_t) ** 2).item()
            psnr = -10 * np.log10(mse) if mse > 0 else 100
            return mse, psnr

        def record(position, mse, psnr):
            """Append one measurement to the training history."""
            history['em_iter'].append(position)
            history['psnr'].append(psnr)
            history['mse'].append(mse)

        best_psnr = 0
        no_improve_count = 0
        completed = 0  # EM rounds finished; stays 0 if the loop never runs

        print("EM Training:")
        print("=" * 60)

        for em_iter in range(num_em_iterations):
            # ============ E-STEP ============
            model.e_step(X_t, Y_t)
            mse, psnr = evaluate()

            print(f"EM Iter {em_iter+1:2d} | After E-step: PSNR = {psnr:.2f} dB")
            record(em_iter + 0.5, mse, psnr)  # Mark E-step

            # ============ M-STEP ============
            # Update reservoir weights to reduce reconstruction error
            for _ in range(m_steps_per_em):
                optimizer.zero_grad()
                loss = model.m_step_loss(X_t, Y_t)
                loss.backward()
                # Gradient clipping for stability
                torch.nn.utils.clip_grad_norm_(reservoir_params, max_norm=1.0)
                optimizer.step()

            mse, psnr_after_m = evaluate()

            print(f"           | After M-step: PSNR = {psnr_after_m:.2f} dB (Δ={psnr_after_m-psnr:+.2f})")
            record(em_iter + 1, mse, psnr_after_m)
            completed = em_iter + 1

            # Early stopping
            if psnr_after_m > best_psnr + 0.1:
                best_psnr = psnr_after_m
                no_improve_count = 0
            else:
                no_improve_count += 1

            if no_improve_count >= patience:
                print(f"\nConverged after {em_iter+1} EM iterations")
                break

        # ============ FINAL E-STEP ============
        # The last M-step changed the reservoir after W_out was solved, so
        # re-solve the readout; otherwise the returned model would pair the
        # final reservoir with a readout fitted to the previous one.
        model.e_step(X_t, Y_t)
        mse, psnr = evaluate()
        print(f"Final readout refit    : PSNR = {psnr:.2f} dB")
        record(completed + 0.5, mse, psnr)

        return model, history

    print("Training function defined")

In [None]:
if USE_TORCH:
    # Hyper-parameters for this run, gathered in one place
    em_config = {
        'hidden_size': 256,
        'num_layers': 3,
        'num_em_iterations': 20,
        'm_steps_per_em': 200,
        'lr': 1e-3,
        'patience': 5,
    }
    # Fit the EM reservoir on the image dataset
    model, history = train_em_reservoir(X, Y, **em_config)

In [None]:
if USE_TORCH:
    # Left panel: PSNR trajectory; right panel: final reconstruction
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # --- PSNR trajectory ---
    ax = axes[0]
    steps, psnr_values = history['em_iter'], history['psnr']
    ax.plot(steps, psnr_values, 'bo-', markersize=6)
    ax.set_xlabel('EM Iteration')
    ax.set_ylabel('PSNR (dB)')
    ax.set_title('EM Training Progress')
    ax.grid(True, alpha=0.3)

    # Tag every point: integer positions were measured after an M-step,
    # half-integer positions after an E-step
    for step, value in zip(steps, psnr_values):
        tag = 'M' if step == int(step) else 'E'
        ax.annotate(tag, (step, value), textcoords="offset points", xytext=(0,10), fontsize=8)

    # --- Final reconstruction ---
    ax = axes[1]
    with torch.no_grad():
        X_t = torch.FloatTensor(X).to(device)
        pred = model.predict(X_t).cpu().numpy()
    ax.imshow(pred.reshape(h, w, 3))
    final_psnr = psnr_values[-1]
    ax.set_title(f'EM-Reservoir Result\nPSNR = {final_psnr:.2f} dB')
    ax.axis('off')

    plt.tight_layout()
    plt.savefig('em_reservoir_training.png', dpi=150)
    plt.show()

## Comparison: Random vs EM-Trained Reservoir

In [None]:
# Compare with random reservoir (standard approach)
def random_reservoir_baseline(X, Y, hidden_size=256, num_layers=3,
                              spectral_radius=0.9, num_iterations=10,
                              lamb=1e-6, seed=42):
    """
    Standard random reservoir with ridge regression.

    Builds ``num_layers`` random recurrent tanh layers, concatenates their
    final states and fits a ridge readout on top.

    Args:
        X: Inputs of shape (n, d).
        Y: Targets with n rows.
        hidden_size: Units per reservoir layer.
        num_layers: Number of stacked layers.
        spectral_radius: Spectral radius each recurrent matrix is scaled to.
        num_iterations: Recurrence steps applied per layer.
        lamb: Ridge regularisation strength.
        seed: Seed of the private random generator.

    Returns:
        (pred, psnr): predictions clipped to [0, 1] and their PSNR in dB.
        An exactly zero MSE is reported as a PSNR of 100.
    """
    # A private RandomState yields the same numbers as np.random.seed(seed)
    # followed by np.random.randn, without resetting the caller's global RNG.
    rng = np.random.RandomState(seed)
    n = X.shape[0]

    all_h = []
    layer_input = X

    for _ in range(num_layers):
        d_in = layer_input.shape[1]
        W_in = rng.randn(d_in, hidden_size) * 0.5
        W_hh = rng.randn(hidden_size, hidden_size)
        # Rescale so the largest |eigenvalue| equals spectral_radius
        eig = np.abs(np.linalg.eigvals(W_hh)).max()
        W_hh = W_hh * (spectral_radius / eig)
        b = rng.randn(hidden_size) * 0.1

        # Start from a zero state and iterate with the input held constant
        h = np.zeros((n, hidden_size))
        for _step in range(num_iterations):
            h = np.tanh(layer_input @ W_in + h @ W_hh + b)

        all_h.append(h)
        layer_input = h

    H = np.hstack(all_h)
    # Ridge regression: W_out = (H^T H + λI)^(-1) H^T Y
    W_out = np.linalg.solve(H.T @ H + lamb * np.eye(H.shape[1]), H.T @ Y)
    pred = np.clip(H @ W_out, 0, 1)

    mse = np.mean((pred - Y) ** 2)
    # Guard log10(0) for a perfect reconstruction
    psnr = -10 * np.log10(mse) if mse > 0 else 100
    return pred, psnr

# Fourier baseline
def fourier_baseline(X, Y, num_features=256, sigma=10, lamb=1e-6, seed=42):
    """
    Random Fourier feature baseline with a ridge-regression readout.

    Args:
        X: Inputs of shape (n, d).
        Y: Targets with n rows.
        num_features: Number of random frequencies; the feature matrix has
            2 * num_features columns (sine and cosine of each projection).
        sigma: Standard deviation of the random frequency matrix.
        lamb: Ridge regularisation strength.
        seed: Seed of the private random generator.

    Returns:
        (pred, psnr): predictions clipped to [0, 1] and their PSNR in dB.
        An exactly zero MSE is reported as a PSNR of 100.
    """
    # A private RandomState yields the same numbers as np.random.seed(seed)
    # followed by np.random.randn, without resetting the caller's global RNG.
    rng = np.random.RandomState(seed)
    # One frequency vector per feature, sized to the input dimension
    B = rng.randn(num_features, X.shape[1]) * sigma
    phase = (2 * np.pi * X) @ B.T  # computed once, shared by sin and cos
    H = np.concatenate([np.sin(phase), np.cos(phase)], axis=1)
    W = np.linalg.solve(H.T @ H + lamb * np.eye(H.shape[1]), H.T @ Y)
    pred = np.clip(H @ W, 0, 1)
    mse = np.mean((pred - Y) ** 2)
    # Guard log10(0) for a perfect reconstruction
    psnr = -10 * np.log10(mse) if mse > 0 else 100
    return pred, psnr

# Run baselines
pred_random, psnr_random = random_reservoir_baseline(X, Y, hidden_size=256, num_layers=3)
pred_fourier, psnr_fourier = fourier_baseline(X, Y, num_features=256, sigma=10)

print("\n" + "=" * 60)
print("COMPARISON")
print("=" * 60)
print(f"Random Reservoir (fixed):    {psnr_random:.2f} dB")
# `history` only exists when the PyTorch training cell ran, so guard every
# use of it; the NumPy baselines must still be reported without PyTorch.
if USE_TORCH:
    print(f"EM-Trained Reservoir:        {history['psnr'][-1]:.2f} dB")
else:
    print("EM-Trained Reservoir:        n/a (PyTorch required)")
print(f"Fourier Features:            {psnr_fourier:.2f} dB")
if USE_TORCH:
    print(f"\nImprovement (EM vs Random):  {history['psnr'][-1] - psnr_random:+.2f} dB")

In [None]:
# Side-by-side panels: target, random reservoir, EM reservoir, Fourier features
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
target_ax, random_ax, em_ax, fourier_ax = axes

target_ax.imshow(img_array)
target_ax.set_title('Target', fontsize=12)

random_ax.imshow(pred_random.reshape(h, w, 3))
random_ax.set_title(f'Random Reservoir\n{psnr_random:.2f} dB', fontsize=12)

if not USE_TORCH:
    # No trained model to show; leave the panel empty with a notice
    em_ax.set_title('PyTorch required')
else:
    with torch.no_grad():
        X_t = torch.FloatTensor(X).to(device)
        pred_em = model.predict(X_t).cpu().numpy()
    em_ax.imshow(pred_em.reshape(h, w, 3))
    em_ax.set_title(f'EM-Trained Reservoir\n{history["psnr"][-1]:.2f} dB', fontsize=12)

fourier_ax.imshow(pred_fourier.reshape(h, w, 3))
fourier_ax.set_title(f'Fourier Features\n{psnr_fourier:.2f} dB', fontsize=12)

# Images need no tick marks or frames
for panel in axes:
    panel.axis('off')

plt.tight_layout()
plt.savefig('em_reservoir_comparison.png', dpi=150)
plt.show()

## Analysis: What Does EM Training Learn?

In [None]:
if USE_TORCH:
    # Analyze learned vs random reservoir weights.
    # NOTE: the "Random" reference is a freshly sampled matrix drawn with the
    # same recipe as the initialisation (NumPy, seed 42 + layer index); it is
    # not the trained model's own starting weights.
    n_show = min(3, model.num_layers)
    hidden = model.hidden_size
    # One column per analysed layer; squeeze=False keeps `axes` 2-D even when
    # only a single layer is shown.
    fig, axes = plt.subplots(2, n_show, figsize=(5 * n_show, 10), squeeze=False)

    # Compare W_hh spectral properties
    for l in range(n_show):
        # Learned W_hh
        W_hh_learned = model.W_hh[l].detach().cpu().numpy()
        eigs_learned = np.linalg.eigvals(W_hh_learned)

        # Random W_hh (for comparison), sized and scaled like the model's.
        # A private RandomState(seed) reproduces np.random.seed(seed) draws
        # without resetting the global RNG.
        rng = np.random.RandomState(42 + l)
        W_hh_random = rng.randn(hidden, hidden)
        eig_max = np.abs(np.linalg.eigvals(W_hh_random)).max()
        W_hh_random = W_hh_random * (model.spectral_radius / eig_max)
        eigs_random = np.linalg.eigvals(W_hh_random)

        ax = axes[0, l]
        ax.scatter(eigs_random.real, eigs_random.imag, alpha=0.5, s=10, label='Random')
        ax.scatter(eigs_learned.real, eigs_learned.imag, alpha=0.5, s=10, label='Learned')
        # Unit circle as a visual reference for the eigenvalue magnitudes
        circle = plt.Circle((0, 0), 1, fill=False, color='r', linestyle='--')
        ax.add_patch(circle)
        ax.set_xlim(-1.5, 1.5)
        ax.set_ylim(-1.5, 1.5)
        ax.set_aspect('equal')
        ax.set_title(f'Layer {l+1}: W_hh Eigenvalues')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    # Weight distribution comparison
    for l in range(n_show):
        W_in_param = model.W_in[l].detach().cpu().numpy()
        W_in_learned = W_in_param.flatten()

        ax = axes[1, l]
        ax.hist(W_in_learned, bins=50, alpha=0.7, density=True, label='Learned W_in')

        # Random for comparison, with the same shape as the learned matrix
        rng = np.random.RandomState(42 + l)
        W_in_random = (rng.randn(*W_in_param.shape) * 0.5).flatten()
        ax.hist(W_in_random, bins=50, alpha=0.5, density=True, label='Random W_in')

        ax.set_title(f'Layer {l+1}: W_in Distribution')
        ax.legend(fontsize=8)

    plt.tight_layout()
    plt.savefig('em_weight_analysis.png', dpi=150)
    plt.show()

## Key Insights

In [None]:
# Print a static, human-readable summary of the notebook's takeaways.
# The text is a fixed literal; nothing in it is computed from the results above.
print("""
══════════════════════════════════════════════════════════════════════
                    KEY INSIGHTS: EM-TRAINED RESERVOIR
══════════════════════════════════════════════════════════════════════

1. EM ALGORITHM FOR RESERVOIR
   ━━━━━━━━━━━━━━━━━━━━━━━━━━
   E-step: Fix reservoir → optimal readout (closed-form ridge)
   M-step: Fix readout → update reservoir (gradient descent)
   
   This decomposes the joint optimization into tractable subproblems.

2. EVOLVING RESERVOIR
   ━━━━━━━━━━━━━━━━━━
   Traditional RC: Random fixed reservoir (never adapts)
   EM-RC: Reservoir evolves to better represent the data
   
   The reservoir learns to generate features that are more
   linearly separable for the readout.

3. WHEN DOES EM HELP?
   ━━━━━━━━━━━━━━━━━━
   ✓ When random initialization is suboptimal
   ✓ When the task requires specific feature structure
   ✓ When you can afford the extra computation
   
   ✗ May not help if random reservoir is already good
   ✗ More expensive than standard RC

4. CONNECTION TO OTHER METHODS
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━
   - Alternating minimization
   - Coordinate descent
   - Two-stage training (pretrain encoder, train decoder)
   - Self-training / iterative refinement

5. SPECTRAL CHANGES
   ━━━━━━━━━━━━━━━━━
   EM training can change the spectral properties of W_hh.
   The learned reservoir may have different dynamics than
   the random initialization.

══════════════════════════════════════════════════════════════════════
""")