In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RBMWavefunction(nn.Module):
    def __init__(self, n_visible, n_hidden):
        super().__init__()
        self.W = nn.Parameter(torch.empty(n_visible, n_hidden))
        nn.init.normal_(self.W, std=0.01)  # Smaller initial weights
        self.a = nn.Parameter(torch.zeros(n_visible))
        self.b = nn.Parameter(torch.zeros(n_hidden))
             # better init

    def forward(self, v):
        # Input: v - shape (batch_size, n_visible), values ±1
        linear_term = torch.matmul(v, self.a)
        activation = torch.matmul(v, self.W) + self.b
        hidden_term = torch.sum(torch.log(torch.cosh(activation + 1e-7)), dim=1)
        return linear_term + hidden_term

    def psi(self, v):
        log_psi = self.forward(v)
        # Subtract max for numerical stability
        return torch.exp(log_psi - log_psi.max())  # wavefunction amplitude


In [34]:
def ising_energy(configs):
    """
    Compute Ising energy for each spin configuration in the batch.
    E = -sum(s_i * s_{i+1}) with periodic boundary conditions
    Input: configs - shape (batch_size, n_spins)
    Output: energy - shape (batch_size,)
    """
    return -torch.sum(configs * torch.roll(configs, shifts=-1, dims=1), dim=1)


In [35]:
def metropolis_sampler(rbm, initial_config, n_steps=50):
    config = initial_config.clone().detach()

    for step in range(n_steps):
        proposal = config.clone()
        flip_indices = torch.randint(0, config.shape[1], (config.shape[0],))

        for i in range(config.shape[0]):
            proposal[i, flip_indices[i]] *= -1  # Flip one spin

        log_prob_old = rbm.forward(config)
        log_prob_new = rbm.forward(proposal)
        accept_ratio = torch.exp(log_prob_new - log_prob_old)

        rand = torch.rand(config.shape[0])
        accept = (rand < accept_ratio).to(torch.float32).unsqueeze(1)

        config = accept * proposal + (1 - accept) * config

    return config.to(torch.float32)


In [36]:
def train_rbm(rbm, n_epochs=500, batch_size=200, lr=1e-2, n_spins=10):
    optimizer = torch.optim.Adam(rbm.parameters(), lr=lr)
    config = (torch.randint(0, 2, (batch_size, n_spins)) * 2 - 1).float()

    for epoch in range(n_epochs):
        samples = metropolis_sampler(rbm, config, n_steps=40)
        config = samples.detach()

        log_psi = rbm.forward(samples)
        energies = ising_energy(samples)
        
        # Stable psi calculation
        max_log_psi = log_psi.max()
        psi_values = torch.exp(log_psi - max_log_psi)
        
        # Energy calculation with detached values
        with torch.no_grad():
            local_energies = energies * psi_values
        
        loss = (local_energies * psi_values).mean()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(rbm.parameters(), 1.0)
        optimizer.step()

        if epoch % 50 == 0:
            print(f"[Epoch {epoch:03d}] Energy: {loss.item():.4f}")


In [37]:
n_spins = 10
n_hidden = 20

rbm = RBMWavefunction(n_visible=n_spins, n_hidden=n_hidden)
train_rbm(rbm, n_epochs=500, batch_size=200, lr=1e-2, n_spins=n_spins)


[Epoch 000] Energy: 0.3111
[Epoch 050] Energy: -0.4895
[Epoch 100] Energy: -0.5959
[Epoch 150] Energy: -0.4922
[Epoch 200] Energy: -0.5252
[Epoch 250] Energy: -0.6144
[Epoch 300] Energy: -0.6607
[Epoch 350] Energy: -0.7113
[Epoch 400] Energy: -0.6936
[Epoch 450] Energy: -0.8681


In [38]:
test_samples = metropolis_sampler(rbm, (torch.randint(0, 2, (100, n_spins)) * 2 - 1).float())

psi_values = rbm.psi(test_samples) 

In [39]:
def export_wavefunction_data(rbm, n_samples=1000, n_spins=10, export_path="wavefunction_data.csv"):
    init = torch.randint(0, 2, (n_samples, n_spins)) * 2 - 1
    samples = metropolis_sampler(rbm, init.float())
    
    with torch.no_grad():
        log_psi = rbm.forward(samples)
        max_log = log_psi.max()
        psi_sq = torch.exp(2*(log_psi - max_log)).numpy()  # Proper |ψ|² calculation
    
    samples_np = samples.numpy()
    df = pd.DataFrame(samples_np, columns=[f"s_{i}" for i in range(n_spins)])
    df["psi_squared"] = psi_sq
    df.to_csv(export_path, index=False)


In [40]:
# After training the RBM
export_wavefunction_data(rbm, n_samples=2000, n_spins=n_spins, export_path="rbm_wavefunction_data.csv")
