In [9]:
from gen_catalyst_design.discrete_space_diffusion import AbsorbingStateNoiser, CosineScheduler
from gen_catalyst_design.discrete_space_diffusion.Dataset import get_dataloaders_from_atoms_list
from ase.io import read
import torch.nn.functional as F
import torch

In [10]:
use_absorbing_state = True
mask_classes = True
element_pool = ["Au","Cu","Pd","Rh","Ni","Ga"]
if use_absorbing_state:
    element_pool = ["(X)"] + element_pool

atoms_list = read("dataset.traj", index=":")
if mask_classes:
    for atoms in atoms_list:
        atoms.info["class"] = 0

train_loader, val_loader = get_dataloaders_from_atoms_list(
    atoms_list=atoms_list,
    element_pool=element_pool,
    batch_size=4,
)

In [11]:
scheduler = CosineScheduler(beta_max=1e-1, beta_min=1e-4)
noiser = AbsorbingStateNoiser(element_pool=element_pool)
noiser.pre_compute_accum_q_matrices(scheduler=scheduler)

In [51]:
random_seed = 42
torch.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)
for batch in train_loader:
    time = scheduler.sample_time(n_samples=batch.batch_size, t_span=(2,scheduler.t_final))
    x_t = noiser.noise_x0_xt(x0_batch=batch.x*1.0, time_batch=time[batch.batch])
    x0s = [F.one_hot(torch.tensor(i+1), num_classes=len(element_pool))*torch.ones(size=(len(x_t), 1)) for i in range(len(element_pool)-1)]
    q_revs_tot = torch.stack([noiser.get_reverse_transition_probabilities(
        x0_batch=x0*1.0,
        x_t_batch=x_t*1.0, 
        time_batch=time[batch.batch], 
        scheduler=scheduler
    ) for x0 in x0s
    ])
    denoise_probs = torch.rand_like(x_t*1.0)
    denoise_probs = denoise_probs/denoise_probs.sum(dim=1, keepdim=True)
    summed_probs = (denoise_probs[None, :, :] * q_revs_tot).sum(dim=0)
    print(q_revs_tot[0])
    break

tensor([[0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9958, 0.0042, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9958, 0.0042, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9958, 0.0042, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9958, 0.0042, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9958, 0.0042, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9958, 0.0042, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9958, 0.0042, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [   nan,    nan,    nan,    nan,    nan,    nan,    nan],
        [ 

In [52]:
print(x_t[14])
print(x0s[0])

tensor([0, 0, 0, 1, 0, 0, 0])
tensor([[0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0.,

In [53]:
print(torch.argwhere(torch.isnan(noiser.accumulated_q_matrices) == True))

tensor([], size=(0, 3), dtype=torch.int64)


In [59]:
print(summed_probs[1])

tensor([2.0285e+00, 1.4161e-03, 1.8398e-04, 1.8024e-04, 1.5108e-04, 3.5765e-04,
        4.8327e-04])
