In [None]:
import torch
import numpy as np
import random

def reverse_step_gradients_corrected(s_v, s_h, x_v_next, x_h_next):
    """Return the per-transition NLL gradients for J_vh, J_hh and b_h.

    The comments in this notebook tie these expressions to equations (10-11)
    of the referenced paper. With ``s = d NLL / d grad V`` (as returned by
    ``reverse_step_nll``) and the energy terms ``x_v^T J_vh x_h`` and
    ``0.5 * x_h^T J_hh x_h``:

        dJ_vh[i, j] = s_v[i] * x_h_next[j] + x_v_next[i] * s_h[j]
        dJ_hh       = symmetrised outer(s_h, x_h_next)
        db_h        = s_h

    Args:
        s_v, s_h: 1-D scaled residuals for the visible / hidden units.
        x_v_next, x_h_next: 1-D states at which the gradient was evaluated.

    Returns:
        ``(grad_Jvh, grad_Jhh, grad_bh)``; ``grad_bh`` is a copy of ``s_h``.
    """
    # Visible->hidden term and hidden->visible term of the J_vh gradient.
    visible_side = torch.outer(s_v, x_h_next)
    hidden_side = torch.outer(x_v_next, s_h)

    # J_hh enters the energy symmetrically, so average the outer product with its transpose.
    hh_raw = torch.outer(s_h, x_h_next)
    hh_sym = 0.5 * (hh_raw + hh_raw.t())

    return visible_side + hidden_side, hh_sym, s_h.clone()


def _clip_by_total_norm(grads, max_norm):
    """Scale ``grads`` in place so their joint L2 norm is at most ``max_norm``; return the pre-clip norm."""
    total_norm = torch.sqrt(sum(g.pow(2).sum() for g in grads))
    if max_norm is not None and total_norm > max_norm:
        # Same coefficient as torch.nn.utils.clip_grad_norm_ (epsilon guards against /0).
        scale = max_norm / (total_norm + 1e-6)
        for g in grads:
            g.mul_(scale)
    return total_norm


def train_model_improved(
    mnist_dataset,
    J_vh, J_hh, b_h,
    num_epochs=5,
    digits_per_epoch=10,
    samples_per_digit=3,
    dt=1e-3,
    tf_train=2.5,
    lr=1e-2,
    device='cpu',
    U=None,
    training_digits=(0, 1, 7),
    max_grad_norm=10.0,
):
    """Train ``J_vh``, ``J_hh`` and ``b_h`` by gradient descent on the reverse-step NLL.

    Each epoch draws ``digits_per_epoch`` images. Each draw picks a random
    label from ``training_digits``. For each image it runs
    ``samples_per_digit`` noising trajectories and accumulates the analytic
    NLL gradient of every consecutive pair of states. The averaged gradient
    is clipped to a joint L2 norm of ``max_grad_norm`` and applied with step
    size ``lr``.

    Args:
        mnist_dataset: indexable dataset of ``(image, label)`` pairs, passed to
            ``sample_digit_from_mnist``.
        J_vh, J_hh, b_h: trainable tensors. They are updated IN PLACE and also
            returned.
        num_epochs: number of parameter updates.
        digits_per_epoch: images drawn per epoch.
        samples_per_digit: noising trajectories run per image.
        dt: integration time step.
        tf_train: duration of the noising phase.
        lr: gradient-descent step size.
        device: device for sampled images and trajectory states.
        U: matrix used as ``U @ P`` to bias the hidden units during noising.
            If None, a module-level ``U`` is used (the original behaviour).
        training_digits: labels to sample images from. The original comment
            attributes the default ``(0, 1, 7)`` to the paper.
        max_grad_norm: clipping threshold for the averaged gradient. None
            disables clipping.

    Returns:
        ``(J_vh, J_hh, b_h, losses)``, where ``losses[e]`` is the mean
        per-transition NLL of epoch ``e``.

    Raises:
        NameError: if ``U`` is None and no module-level ``U`` exists.
        ValueError: if an epoch produced no state transitions. This is raised
            before the parameters are modified.
    """
    if U is None:
        # Backward compatibility: earlier versions read ``U`` from the enclosing
        # (notebook) namespace instead of taking it as an argument.
        U = globals().get("U")
        if U is None:
            raise NameError("train_model_improved needs U: pass U=... or define a global U")

    training_digits = list(training_digits)
    losses = []

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        gradient_count = 0

        print(f"\n=== Epoch {epoch+1}/{num_epochs} ===")

        # Gradients are summed over every transition of every trajectory, then averaged.
        grad_Jvh_epoch = torch.zeros_like(J_vh)
        grad_Jhh_epoch = torch.zeros_like(J_hh)
        grad_bh_epoch = torch.zeros_like(b_h)

        for _ in range(digits_per_epoch):
            digit = random.choice(training_digits)
            P = sample_digit_from_mnist(mnist_dataset, digit, device)

            for _ in range(samples_per_digit):
                states_all, _ = run_noising_trajectory_improved(
                    P, J_vh, J_hh, b_h, U, tf=tf_train, dt=dt,
                    t_eq=0.5, bias_scale=20.0, device=device,
                )

                # Score each forward transition (x_k -> x_{k+1}) under the reverse model.
                for (x_v, x_h), (x_v_next, x_h_next) in zip(states_all[:-1], states_all[1:]):
                    nll, (s_v, s_h), _ = reverse_step_nll(
                        x_v, x_h, x_v_next, x_h_next,
                        J_vh, J_hh, b_h, dt, mu=1.0, kBT=1.0,
                    )

                    dJvh, dJhh, dbh = reverse_step_gradients_corrected(s_v, s_h, x_v_next, x_h_next)

                    grad_Jvh_epoch += dJvh
                    grad_Jhh_epoch += dJhh
                    grad_bh_epoch += dbh
                    gradient_count += 1

                    epoch_loss += nll.item()

        # Fail before touching the parameters: dividing by zero below would
        # write non-finite values into J_vh / J_hh / b_h.
        if gradient_count == 0:
            raise ValueError(
                "no state transitions were sampled in this epoch; check "
                "digits_per_epoch, samples_per_digit, tf_train and dt"
            )

        grads = [grad_Jvh_epoch, grad_Jhh_epoch, grad_bh_epoch]
        for g in grads:
            g /= gradient_count

        # Clip the manually computed gradients BEFORE the update.
        # torch.nn.utils.clip_grad_norm_ only rescales ``.grad`` attributes,
        # which this loop never populates, so it cannot be used here.
        _clip_by_total_norm(grads, max_grad_norm)

        # no_grad keeps the in-place update legal even if the parameters are
        # leaf tensors with requires_grad=True.
        with torch.no_grad():
            J_vh -= lr * grad_Jvh_epoch
            J_hh -= lr * grad_Jhh_epoch
            b_h -= lr * grad_bh_epoch

        avg_loss = epoch_loss / gradient_count
        losses.append(avg_loss)

        print(f"Epoch {epoch+1} | Avg NLL: {avg_loss:.4f}")
        print(f"Parameter norms - J_vh: {J_vh.norm():.4f}, J_hh: {J_hh.norm():.4f}, b_h: {b_h.norm():.4f}")

    return J_vh, J_hh, b_h, losses


def run_noising_trajectory_improved(
    P, J_vh, J_hh, b_h, U,
    tf=2.5, dt=1e-3, t_eq=0.5,
    bias_scale=20.0, device='cpu',
    use_trainable=False
):
    """Simulate the forward (noising) trajectory for one flattened image ``P``.

    The simulation runs in two phases:

    1. Starting from standard-normal noise, equilibrate for ``t_eq`` under the
       full image bias. The visible bias is ``P * bias_scale`` and the hidden
       bias is ``(U @ P) * bias_scale``.
    2. Integrate for ``tf`` while that bias fades linearly toward zero.

    The couplings are zero unless ``use_trainable`` is True, in which case
    ``J_vh`` / ``J_hh`` are used.

    Args:
        P: 1-D image vector; its length sets the number of visible units.
        J_vh, J_hh, b_h: model parameters. ``J_vh`` and ``J_hh`` are only used
            when ``use_trainable`` is True.
        U: matrix with ``U.shape[0]`` hidden units, applied as ``U @ P``.
        tf: noising-phase duration.
        dt: time step.
        t_eq: equilibration duration.
        bias_scale: bias multiplier.
        device: device for the initial noise.
        use_trainable: if True, use ``J_vh`` / ``J_hh`` instead of zeros.

    Returns:
        ``(states_all, None)``. ``states_all`` is a list of ``K + 1`` cloned
        ``(x_v, x_h)`` pairs: the equilibrated state followed by one state per
        noising step, where ``K = round(tf / dt)``.
    """
    # round() rather than bare int(): int() truncates, so a quotient such as
    # 0.3 / 0.1 == 2.9999999999999996 would silently lose a step.
    K = int(round(tf / dt))
    K_eq = int(round(t_eq / dt))

    # Initialize with noise
    x_v = torch.randn(P.shape[0], device=device)
    x_h = torch.randn(U.shape[0], device=device)

    # During noising, we DON'T use trainable couplings (unless asked to).
    if use_trainable:
        Jvh_use, Jhh_use = J_vh, J_hh
    else:
        Jvh_use = torch.zeros_like(J_vh)
        Jhh_use = torch.zeros_like(J_hh)
    # NOTE(review): the trainable hidden bias b_h is never applied, even when
    # use_trainable=True; only the image-driven bias below is used. Confirm
    # whether b_h should be added to it in that mode.

    # Full-strength biases are time-independent; compute them once.
    b_v_full = P * bias_scale
    b_h_full = (U @ P) * bias_scale

    # Equilibration phase with full bias
    for _ in range(K_eq):
        x_v, x_h, _ = euler_maruyama_step(
            x_v, x_h, b_v_full, b_h_full, Jvh_use, Jhh_use, dt
        )

    states_all = [(x_v.clone(), x_h.clone())]

    # Noising phase: the bias fades linearly, from 1 at the first step to 1/K at the last.
    for k in range(K):
        fade = 1.0 - (k / K)
        x_v, x_h, _ = euler_maruyama_step(
            x_v, x_h, b_v_full * fade, b_h_full * fade, Jvh_use, Jhh_use, dt
        )
        states_all.append((x_v.clone(), x_h.clone()))

    return states_all, None


def sample_digit_from_mnist(mnist, digit, device='cpu'):
    """Return one randomly chosen image labelled ``digit``, flattened to 1-D on ``device``.

    Args:
        mnist: indexable, iterable dataset of ``(image_tensor, label)`` pairs.
        digit: label to select.
        device: target device for the returned tensor.

    Raises:
        IndexError: if ``mnist`` contains no sample with label ``digit``.
    """
    # Full scan of the dataset on every call.
    idxs = [i for i, (_, y) in enumerate(mnist) if y == digit]
    if not idxs:
        raise IndexError(f"no samples with label {digit!r} in dataset")
    v, _ = mnist[random.choice(idxs)]
    # reshape (not view) also accepts non-contiguous image tensors.
    return v.reshape(-1).to(device)


def euler_maruyama_step(x_v, x_h, b_v, b_h, J_vh, J_hh, dt, kBT=1.0, mu=1.0):
    """Advance visible and hidden states by one Euler-Maruyama step of overdamped Langevin dynamics.

    Each state is updated as ``x <- x - mu * grad V(x) * dt + sqrt(2 * mu * kBT * dt) * N(0, 1)``.

    Returns:
        ``(x_v_new, x_h_new, V)``, where ``V`` is the energy at the OLD state
        (a Python float from ``energy_and_grad``).
    """
    energy, grad_v, grad_h = energy_and_grad(x_v, x_h, b_v, b_h, J_vh, J_hh)

    # Draw the visible noise first, then the hidden noise.
    eta_v, eta_h = torch.randn_like(x_v), torch.randn_like(x_h)

    noise_scale = torch.tensor(
        2.0 * mu * kBT * dt, device=x_v.device, dtype=x_v.dtype
    ).sqrt()

    def advance(pos, grad, eta):
        # Deterministic drift down the energy gradient plus diffusive kick.
        return pos - mu * grad * dt + noise_scale * eta

    return advance(x_v, grad_v, eta_v), advance(x_h, grad_h, eta_h), energy


def energy_and_grad(x_v, x_h, b_v, b_h, J_vh, J_hh, J2=10.0, J4=10.0):
    """Evaluate the energy and its gradient with respect to the visible and hidden states.

    The energy is::

        V = J2 * (sum x_v^2 + sum x_h^2) + J4 * (sum x_v^4 + sum x_h^4)
            + b_v . x_v + b_h . x_h + x_v^T J_vh x_h + 0.5 * x_h^T J_hh x_h

    Differentiation is done only with respect to copies of ``x_v`` and
    ``x_h``, using ``torch.autograd.grad``. This has two consequences:

    * Parameters that have ``requires_grad=True`` (for example ``J_vh``) never
      get gradients accumulated into their ``.grad``.
    * The function also works when called inside ``torch.no_grad()``.

    The inputs are not modified.

    Returns:
        ``(V, g_v, g_h)``, where ``V`` is a Python float and ``g_v`` / ``g_h``
        are detached tensors shaped like ``x_v`` / ``x_h``.
    """
    b_v = b_v.to(x_v.device)
    b_h = b_h.to(x_h.device)

    # Force graph recording even if the caller disabled it.
    with torch.enable_grad():
        x_v = x_v.detach().clone().requires_grad_(True)
        x_h = x_h.detach().clone().requires_grad_(True)

        # Energy terms
        V = J2 * (x_v.pow(2).sum() + x_h.pow(2).sum()) \
            + J4 * (x_v.pow(4).sum() + x_h.pow(4).sum())
        V = V + (b_v @ x_v) + (b_h @ x_h)
        V = V + (x_v @ J_vh @ x_h) + 0.5 * (x_h @ J_hh @ x_h)

        # Gradient only w.r.t. the states; unlike V.backward(), this does not
        # write into the .grad of J_vh / J_hh / b_h / b_v.
        g_v, g_h = torch.autograd.grad(V, (x_v, x_h))

    return V.item(), g_v.detach(), g_h.detach()


def reverse_step_nll(
    x_v, x_h, x_v_next, x_h_next,
    J_vh, J_hh, b_h, dt,
    mu=1.0, kBT=1.0, J2=10.0, J4=10.0
):
    """Score one forward transition ``x -> x_next`` under the reverse Langevin model.

    The residual ``r = (x - x_next) + mu * dt * grad V(x_next)`` is the noise
    the reverse step would have needed to map ``x_next`` back to ``x``. The
    energy gradient is taken with the trainable parameters and a zero visible
    bias.

    Returns:
        A tuple ``(nll, (s_v, s_h), (x_v_next, x_h_next))``:

        * ``nll = |r|^2 / (4 * mu * kBT * dt)``, a Gaussian NLL with variance
          ``2 * mu * kBT * dt``, up to a constant.
        * ``s = r / (2 * kBT)``, the derivative of ``nll`` with respect to
          ``grad V(x_next)``.
        * The input ``x_v_next`` and ``x_h_next``, returned unchanged.
    """
    # The reverse model has no trainable visible bias, so use zeros.
    zero_visible_bias = torch.zeros_like(x_v_next)
    _, grad_v, grad_h = energy_and_grad(
        x_v_next, x_h_next, zero_visible_bias, b_h, J_vh, J_hh, J2, J4
    )

    res_v = mu * grad_v * dt - (x_v_next - x_v)
    res_h = mu * grad_h * dt - (x_h_next - x_h)

    four_mu_kt_dt = 4.0 * mu * kBT * dt
    nll = (res_v.pow(2).sum() + res_h.pow(2).sum()) / four_mu_kt_dt

    scaled = (res_v / (2.0 * kBT), res_h / (2.0 * kBT))
    return nll, scaled, (x_v_next, x_h_next)
