In [12]:
import torch
from torch.distributions import Normal,Categorical

#configs
from markov_bridges.configs.config_classes.data.molecules_configs import QM9Config
from markov_bridges.configs.config_classes.generative_models.cmb_config import CMBConfig
from markov_bridges.configs.config_classes.networks.mixed_networks_config import MixedEGNN_dynamics_QM9Config

#models
from markov_bridges.models.generative_models.cmb_lightning import MixedForwardMapL

# loaders
from markov_bridges.data.dataloaders_utils import get_dataloaders
from markov_bridges.models.networks.temporal.mixed.mixed_networks_utils import load_mixed_network
from markov_bridges.utils.paralellism import check_model_devices


from markov_bridges.utils.equivariant_diffusion import (
    assert_mean_zero_with_mask, 
    remove_mean_with_mask,
    check_mask_correct, 
    sample_center_gravity_zero_gaussian_with_mask,
    random_rotation,
)

from collections import namedtuple

In [2]:
# Reproducibility: the cells below draw random source samples / times with torch's
# global RNG, so seed it here (torch only — other libraries are not seeded).
SEED = 0
torch.manual_seed(SEED)

# Small QM9 subset and a 1-layer network for quick experimentation.
N_TRAIN, N_TEST, N_VALID = 1000, 200, 200
N_LAYERS = 1
CONDITIONING = ['H_thermo', 'homo']

config = CMBConfig()
config.data = QM9Config(num_pts_train=N_TRAIN,
                        num_pts_test=N_TEST,
                        num_pts_valid=N_VALID)
config.mixed_network = MixedEGNN_dynamics_QM9Config(n_layers=N_LAYERS,
                                                    conditioning=CONDITIONING)

In [3]:
dtype = torch.float32

# Build the loaders first; the Lightning model is constructed from them.
dataloader = get_dataloaders(config)
model = MixedForwardMapL(config, dataloader, save=False)

# Pull one batch and cast every tensor used below to the working dtype.
databatch = dataloader.get_databatch()
x = databatch['positions'].to(dtype)
# Insert an axis at dim 2 (presumably so the mask broadcasts against per-atom features — TODO confirm).
node_mask = databatch['atom_mask'].to(dtype).unsqueeze(2)
edge_mask = databatch['edge_mask'].to(dtype)
one_hot = databatch['one_hot'].to(dtype)
# Charges are also moved onto the same device as the positions.
charges = databatch['charges'].to(device=x.device, dtype=dtype)

Conditioning on ['H_thermo', 'homo']
Entropy of n_nodes: H[N] -2.475700616836548


  probs = Categorical(torch.tensor(probs))


In [35]:
def augment_noise(x,one_hot,node_mask,charges,augment_noise=0.,data_augmentation=False):
    """Center positions, optionally jitter/rotate them, and pack the node features.

    Args:
        x: atom positions; centered with ``remove_mean_with_mask`` using ``node_mask``.
        one_hot: categorical node features, returned as ``h['categorical']``.
        node_mask: mask used for centering, noise sampling and mask checks.
        charges: integer node features, returned as ``h['integer']``.
        augment_noise: scale of the center-of-gravity-zero Gaussian noise added to
            the positions; noise is only added when this is strictly positive.
        data_augmentation: if True, apply ``random_rotation`` to the centered
            positions and detach the result from the autograd graph.

    Returns:
        ``(x, h)``: the processed positions and
        ``{'categorical': one_hot, 'integer': charges}``.

    Raises:
        Whatever ``check_mask_correct`` / ``assert_mean_zero_with_mask`` raise when
        masking or centering is inconsistent (presumably AssertionError — TODO confirm).
    """
    # The parameter shadows this function's name; alias it to keep the body readable.
    noise_scale = augment_noise

    positions = remove_mean_with_mask(x, node_mask)
    if noise_scale > 0:
        # Noise eps ~ N(0, noise_scale) with zero center of gravity over the masked atoms.
        jitter = sample_center_gravity_zero_gaussian_with_mask(positions.size(), positions.device, node_mask)
        positions = positions + jitter * noise_scale
    # Re-center after (optional) jitter.
    positions = remove_mean_with_mask(positions, node_mask)
    if data_augmentation:
        positions = random_rotation(positions).detach()

    check_mask_correct([positions, one_hot, charges], node_mask)
    assert_mean_zero_with_mask(positions, node_mask)
    return positions, {'categorical': one_hot, 'integer': charges}

def cmb_source_and_nametuple(x,h,config,databatch=None):
    """Sample CMB source variables and bundle them with targets and conditioning.

    Targets come from the positions ``x`` and the one-hot node types in
    ``h['categorical']``. Sources are a uniform categorical draw over
    ``config.data.vocab_size`` classes (discrete) and a standard-normal draw with
    the same flattened size as the positions (continuous). One time value in
    [0, 1) is drawn per sample.

    Args:
        x: positions tensor; dim 0 is treated as the batch dimension and the rest
            is flattened into ``target_continuous``.
        h: dict with at least ``'categorical'``; argmax over dim 2 gives
            ``target_discrete``.
        config: object exposing ``mixed_network.conditioning`` (iterable of
            property names; may be empty or None) and ``data.vocab_size``.
        databatch: mapping providing ``'num_atoms'``, ``'atom_mask'``,
            ``'edge_mask'`` and every conditioning key. If omitted, falls back to
            the module-level (notebook) global ``databatch`` for backward
            compatibility with existing calls.

    Returns:
        A ``DatabatchClass`` namedtuple with fields ``num_atoms, source_discrete,
        source_continuous, target_discrete, target_continuous``, then one field
        per conditioning name, then ``atom_mask, edge_mask, time``.

    Raises:
        NameError: if ``databatch`` is not passed and no global ``databatch`` exists.
    """
    if databatch is None:
        # Older calls relied on the notebook global; keep that working but make it explicit.
        if "databatch" not in globals():
            raise NameError("cmb_source_and_nametuple needs a `databatch` argument "
                            "(no global `databatch` is defined)")
        databatch = globals()["databatch"]

    # Accept None/empty conditioning instead of crashing in the field-name join.
    conditioning = list(config.mixed_network.conditioning or [])
    field_names = (
        ["num_atoms", "source_discrete", "source_continuous",
         "target_discrete", "target_continuous"]
        + conditioning
        + ["atom_mask", "edge_mask", "time"]
    )
    DatabatchNametuple = namedtuple("DatabatchClass", field_names)

    device = x.device
    target_discrete = torch.argmax(h["categorical"], dim=2)
    # NOTE(review): assumes h["categorical"] has config.data.vocab_size classes — TODO confirm.
    vocab_size = config.data.vocab_size
    data_size = target_discrete.size(0)
    discrete_dimensions = target_discrete.size(1)
    target_continuous = x.reshape(data_size, -1)
    continuous_dimensions = target_continuous.size(1)

    # Discrete source: i.i.d. uniform over the vocabulary.
    uniform_probability = torch.full((vocab_size,), 1. / vocab_size)
    source_discrete = Categorical(uniform_probability).sample((data_size, discrete_dimensions)).to(device)

    # Continuous source: i.i.d. standard normal, same flattened size as the target.
    source_continuous = Normal(0., 1.).sample((data_size, continuous_dimensions)).to(device)

    # One time value per sample.
    time = torch.rand((data_size,)).to(device)

    items = [databatch["num_atoms"], source_discrete, source_continuous, target_discrete, target_continuous]
    items.extend(databatch[prop] for prop in conditioning)
    items.extend([databatch["atom_mask"], databatch["edge_mask"], time])
    return DatabatchNametuple(*items)

In [37]:
# Center the positions (defaults: no extra noise, no random rotation) and pack node features,
# then sample the CMB sources and bundle everything into a namedtuple.
x, h = augment_noise(x, one_hot, node_mask, charges, augment_noise=0., data_augmentation=False)
databatch_nametuple = cmb_source_and_nametuple(x, h, config)