In [None]:
import torch
import numpy as np

from src.diffusion import DiffusionSchedule
from src.sampler import DiffusionSampler
from src.geom import ca_bond_lengths

# Coordinates are expressed in "model units": Angstrom divided by scale_factor.
scale_factor = 10.0
L = 32
bond = 3.8 / scale_factor  # ideal CA-CA distance, in model units

# Toy clean structure: a straight chain along the x axis with perfect CA spacing.
x0 = torch.zeros(1, L, 3)
x0[0, :, 0] = bond * torch.arange(L)

schedule = DiffusionSchedule(T=1000)
schedule = schedule.to(torch.device("cpu"))

# Noise the clean chain forward to an intermediate timestep.
t_start = 200
t = torch.tensor([t_start])

# Seeded generator so the forward-noising draw is reproducible run to run.
g = torch.Generator()
g.manual_seed(0)
noise = torch.randn(x0.shape, generator=g)

x_t, _ = schedule.q_sample(x0, t, noise=noise)

class Oracle(torch.nn.Module):
    """Perfect noise predictor for a single, known clean structure ``x0``.

    Given a noised sample ``x_t`` at timestep ``t``, ``forward`` inverts
    ``x_t = sqrt_alpha_bars[t] * x0 + sqrt_one_minus_alpha_bars[t] * eps``
    and returns the exact ``eps``. It stands in for a trained model so a
    sampler can be checked in isolation: with deterministic reverse steps
    it should recover ``x0``.

    Args:
        x0: Clean target coordinates with a leading batch dim of 1 (it is
            expanded along dim 0 to the batch size of ``x_t``).
        schedule: Object exposing per-timestep tables ``sqrt_alpha_bars`` and
            ``sqrt_one_minus_alpha_bars`` that can be indexed by a timestep
            tensor.
    """

    def __init__(self, x0, schedule):
        super().__init__()
        # Buffer (not a Parameter): follows .to(device) but is never trained.
        self.register_buffer("x0", x0)
        self.schedule = schedule

    def forward(self, x_t, t, mask=None, x0_self_cond=None):
        """Return the noise ``eps`` that maps ``self.x0`` to ``x_t`` at step ``t``.

        Args:
            x_t: Noised coordinates; dim 0 is the batch size ``B``.
            t: Timestep(s): a Python int, a 0-d tensor, or a tensor with
                1 or ``B`` elements. A single timestep is broadcast to the
                whole batch.
            mask: Unused; accepted for interface compatibility with the model.
            x0_self_cond: Unused; accepted for interface compatibility.

        Returns:
            Tensor shaped like ``x_t`` holding the exact noise.

        Raises:
            ValueError: If ``t`` has neither 1 nor ``B`` elements.
        """
        B = x_t.shape[0]
        sqrt_ab_table = self.schedule.sqrt_alpha_bars
        sqrt_omb_table = self.schedule.sqrt_one_minus_alpha_bars

        # Build an integer index on the tables' own device so the lookup works
        # whether t arrives as an int, a 0-d tensor, or a tensor elsewhere.
        t = torch.as_tensor(t, device=sqrt_ab_table.device).long().reshape(-1)
        if t.numel() == 1:
            # One timestep for the whole batch (e.g. t passed as an int with B > 1).
            t = t.expand(B)
        elif t.numel() != B:
            raise ValueError(
                f"t has {t.numel()} elements; expected 1 or batch size {B}"
            )

        # Broadcast the single clean target over the batch.
        x0 = self.x0.expand(B, *self.x0.shape[1:])
        # Per-sample coefficients, reshaped to broadcast over the non-batch dims
        # and cast to x_t's dtype/device.
        coef_shape = (B,) + (1,) * (x_t.dim() - 1)
        sqrt_ab = sqrt_ab_table[t].to(x_t).view(coef_shape)
        sqrt_omb = sqrt_omb_table[t].to(x_t).view(coef_shape)
        return (x_t - sqrt_ab * x0) / sqrt_omb

oracle = Oracle(x0, schedule)
sampler = DiffusionSampler(oracle, schedule)

# Reverse process from the noised x_t with no injected noise and no
# self-conditioning; with an exact-eps oracle this should land back on x0.
x_recon = sampler.sample_from(
    x_t, start_t=t_start, verbose=False, add_noise=False, use_self_cond=False
)

rmse = (x_recon - x0).pow(2).mean().sqrt().item()
print("recon RMSE (model units):", rmse)

# Convert back to Angstrom (ideal CA-CA spacing is 3.8) before measuring bonds.
# detach()/cpu() so .numpy() works even if the sampler returns a grad-tracking
# or non-CPU tensor.
bonds_A = ca_bond_lengths(
    x_recon.detach().cpu().squeeze(0).numpy() * scale_factor
)
print("bond mean/min/max (Å):", bonds_A.mean(), bonds_A.min(), bonds_A.max())
