In [1]:
from gen_catalyst_design.discrete_space_diffusion import AbsorbingStateNoiser, CosineScheduler, LinearScheduler, ExponentialScheduler
from gen_catalyst_design.discrete_space_diffusion.Dataset import get_dataloaders_from_atoms_list
from ase.io import read
import torch.nn.functional as F
import torch
from ase.io import write

In [2]:
# --- Notebook configuration -------------------------------------------------
# use_absorbing_state: prepend the extra "(X)" entry to the element pool.
# mask_classes: overwrite every structure's "class" label with 0.
use_absorbing_state = True
mask_classes = True

element_pool = ["Au", "Cu", "Pd", "Rh", "Ni", "Ga"]
if use_absorbing_state:
    # "(X)" goes first, so it occupies index 0 of the pool.
    element_pool = ["(X)", *element_pool]

# --- Data -------------------------------------------------------------------
# index=":" asks ASE for every image stored in the trajectory file.
atoms_list = read("dataset.traj", index=":")
if mask_classes:
    for atoms in atoms_list:
        atoms.info["class"] = 0

train_loader, val_loader = get_dataloaders_from_atoms_list(
    atoms_list=atoms_list,
    element_pool=element_pool,
    batch_size=10,
    do_initial_shuffling=True,
)

In [3]:
# Round-trip check on the first training batch only: split the batch into its
# graphs, convert each graph back to an Atoms object and write them to
# "batched.traj".
for batch in train_loader:
    print(type(batch))
    graphs = batch.to_data_list()
    # Use a dedicated name here: rebinding `atoms_list` would silently replace
    # the full dataset loaded earlier with this single batch of structures.
    batch_atoms_list = [graph.to_atoms(element_pool) for graph in graphs]
    write(filename="batched.traj", images=batch_atoms_list)
    break

<class 'abc.GraphBatch'>
torch.Size([21, 3])
torch.Size([21, 3])
torch.Size([21, 3])
torch.Size([21, 3])
torch.Size([21, 3])
torch.Size([21, 3])
torch.Size([21, 3])
torch.Size([21, 3])
torch.Size([21, 3])
torch.Size([21, 3])


In [None]:
# Noise schedule used by every noising / posterior call below.
scheduler = CosineScheduler(
    beta_min=1e-3,
    beta_max=1.0,
)

# The noiser works on the element pool defined in the configuration cell.
noiser = AbsorbingStateNoiser(element_pool=element_pool)
# Precompute the noiser's accumulated Q matrices for this scheduler.
noiser.pre_compute_accum_q_matrices(scheduler=scheduler)

In [None]:
def cross_entropy(p, q):
    """Summed cross-entropy ``-sum(p * log(q))`` restricted to entries with ``q > 0``.

    Entries where ``q`` is not strictly positive (zero, negative or NaN) are
    left out of the sum regardless of the value of ``p`` there, so ``log`` is
    never evaluated at a non-positive value.

    Args:
        p: tensor of target probabilities.
        q: tensor of predicted probabilities; the boolean mask ``q > 0`` is
            used to index both ``p`` and ``q``.

    Returns:
        A 0-dim tensor with the negated sum over the selected entries
        (zero when no entry of ``q`` is positive).
    """
    support = q > 0.0
    log_q = torch.log(q[support])
    return -torch.sum(p[support] * log_q)

def get_denoise_probs(
        denoise_logits,
        x_t,
        batch,
        time,
        scheduler
    ):
    """Mix the x0-conditioned reverse-step distributions using the predicted x0 probabilities.

    For every candidate clean state ``i`` in the module-level ``element_pool``,
    the module-level ``noiser`` is asked for the reverse transition
    probabilities given ``x_t`` and ``x0 = i``. These are combined as
    ``sum_i softmax(denoise_logits)[n, i] * q_rev_i[n, :]`` and each row is
    then renormalised to sum to one.

    Args:
        denoise_logits: per-row logits over the candidate x0 states; the last
            dimension must have ``len(element_pool)`` entries.
        x_t: noised states, one row per node (assumed one-hot over
            ``element_pool`` - TODO confirm against the noiser).
        batch: object whose ``batch`` attribute is used to index ``time``
            (presumably the node-to-graph assignment - TODO confirm).
        time: time steps, indexed by ``batch.batch`` before being passed on.
        scheduler: forwarded unchanged to the noiser.

    Returns:
        Tensor with one row per row of ``x_t``, normalised along dim 1.
        A row whose mixture sums to zero is divided 0/0 and comes out as NaN.
    """
    denoise_probs = F.softmax(denoise_logits, dim=-1)
    n_rows = len(x_t)
    n_classes = len(element_pool)
    # Loop-invariant: the same per-row time tensor is used for every x0 candidate.
    row_time = time[batch.batch]
    # Row i of the identity matrix is the one-hot vector of class i. Building
    # it on x_t's device keeps the noiser inputs on a single device.
    one_hots = torch.eye(n_classes, device=x_t.device)

    # Shape (n_classes, n_rows, ...): entry i holds the reverse transition
    # probabilities under the hypothesis that every row has x0 = class i.
    q_revs_tot = torch.stack([
        noiser.get_reverse_transition_probabilities(
            x0_batch=one_hots[class_index].repeat(n_rows, 1),
            x_t_batch=x_t*1.0,
            time_batch=row_time,
            scheduler=scheduler
        )
        for class_index in range(n_classes)
    ])

    # Weight hypothesis i by the predicted probability of x0 = i. The class
    # axis of denoise_probs must line up with the stacked x0 axis (dim 0);
    # broadcasting it along the output-category axis instead would make the
    # weight constant over the axis being summed.
    x0_weights = denoise_probs.transpose(0, 1)[:, :, None]
    summed_probs = (x0_weights * q_revs_tot).sum(dim=0)

    return summed_probs/summed_probs.sum(dim=1, keepdim=True)
    

In [None]:
# Seed the CPU and CUDA generators so the sampled times and noise repeat.
random_seed = 42
torch.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)

# Smoke test of the loss on the first training batch only.
for batch in train_loader:
    loss = 0.0

    # One time step per graph, drawn from the span [800, t_final];
    # time[batch.batch] then indexes it with the batch assignment vector.
    time = scheduler.sample_time(
        n_samples=batch.batch_size,
        t_span=(800, scheduler.t_final),
    )
    x_t = noiser.noise_x0_xt(
        x0_batch=batch.x*1.0,
        time_batch=time[batch.batch],
    )

    # Target distribution: reverse transition probabilities given the true x0.
    q_revs_loss = noiser.get_reverse_transition_probabilities(
        x0_batch=batch.x*1.0,
        x_t_batch=x_t*1.0,
        time_batch=time[batch.batch],
        scheduler=scheduler,
    )
    print(x_t[0])

    # Uniform random numbers are used as logits here; no model is involved.
    denoise_logits = torch.rand_like(x_t*1.0)
    normalized_probs = get_denoise_probs(
        denoise_logits=denoise_logits,
        x_t=x_t*1.0,
        batch=batch,
        time=time,
        scheduler=scheduler,
    )
    print(normalized_probs[0])

    loss = loss + cross_entropy(p=q_revs_loss, q=normalized_probs)
    break

In [None]:
# Build `x_1` in this cell so it exists when the notebook is run top to bottom:
# noise the current batch with every graph's time index set to 1.
time_one = torch.ones(size=(batch.batch_size,), dtype=torch.long)
x_1 = noiser.noise_x0_xt(
    x0_batch=batch.x*1.0,
    time_batch=time_one[batch.batch]
)
x_1[8]

In [None]:
# Scratch check that these seven values add up to ~1 (presumably a probability
# row copied from a printed tensor - TODO confirm or delete this cell).
torch.tensor([0.0950, 0.2330, 0.1178, 0.1900, 0.1002, 0.1063, 0.1576]).sum()