In [1]:
import torch
import unittest

from pprint import pprint


from graph_bridges.models.generative_models.sb import SB
from graph_bridges.configs.config_sb import SBTrainerConfig
from graph_bridges.data.graph_dataloaders_config import EgoConfig
from graph_bridges.data.spin_glass_dataloaders_config import ParametrizedSpinGlassHamiltonianConfig
from graph_bridges.configs.graphs.graph_config_sb import SBConfig
from graph_bridges.configs.config_sb import ParametrizedSamplerConfig, SteinSpinEstimatorConfig

from graph_bridges.models.temporal_networks.mlp.temporal_mlp import TemporalMLPConfig
from graph_bridges.models.backward_rates.ctdd_backward_rate_config import BackwardRateTemporalHollowTransformerConfig
from graph_bridges.models.temporal_networks.transformers.temporal_hollow_transformers import TemporalHollowTransformerConfig
from graph_bridges.models.trainers.sb_training import SBTrainer
from graph_bridges.models.spin_glass.spin_utils import copy_and_flip_spins

In [2]:
from graph_bridges.data.transforms import SpinsToBinaryTensor
# Transform object from graph_bridges.data.transforms.
# NOTE(review): `spins_to_binary` is not referenced by any later cell visible here;
# remove it (and its import) if nothing else in the notebook uses it.
spins_to_binary = SpinsToBinaryTensor()

In [89]:


# Build a small spin-glass Schrödinger-bridge (SB) experiment, draw one training
# batch, and evaluate the backward-ratio estimator on it at random times.

# Seed torch's global RNG so random times (and any torch-based data sampling)
# are reproducible under Restart & Run All. The value itself is arbitrary.
SEED = 0
torch.manual_seed(SEED)

sb_config = SBConfig(delete=True,
                     experiment_name="spin_glass",
                     experiment_type="sb",
                     experiment_indentifier=None)

batch_size = 2
number_of_spins = 3

# Data and target use identical spin-glass settings. Build each from the same
# kwargs so the two cannot drift apart (each still gets its own config object).
spin_glass_kwargs = dict(data="bernoulli_small",
                         batch_size=batch_size,
                         number_of_spins=number_of_spins)
sb_config.data = ParametrizedSpinGlassHamiltonianConfig(**spin_glass_kwargs)
sb_config.target = ParametrizedSpinGlassHamiltonianConfig(**spin_glass_kwargs)

sb_config.temp_network = TemporalMLPConfig(time_embed_dim=12, hidden_dim=250)

sb_config.flip_estimator = SteinSpinEstimatorConfig(stein_sample_size=2000, stein_epsilon=0.1)
sb_config.sampler = ParametrizedSamplerConfig(num_steps=20)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

sb = SB()
sb.create_new_from_config(sb_config, device)

batchdata = next(iter(sb.data_dataloader.train()))

X_spins = batchdata[0].to(device)
X_copy_spin, X_flipped_spin = copy_and_flip_spins(X_spins)
# Size the times by the batch actually loaded, not the configured batch size.
# Draw them on CPU and then move them, so the seeded values are the same on CPU and GPU.
current_time = torch.rand(X_spins.size(0)).to(device)
copy_time = torch.repeat_interleave(current_time, X_spins.size(1))

phi_0 = sb.reference_process
phi_1 = sb.training_model

flip_estimate_ = sb.backward_ratio_estimator.flip_estimator(phi_1, X_spins, current_time)
loss = sb.backward_ratio_estimator(phi_1, phi_0, X_spins, current_time)

Coupling from Spin File Different from Config
Coupling from Spin File Different from Config
Coupling from Spin File Different from Config
Coupling from Spin File Different from Config


In [55]:
# Output of the training model on the spin-flipped states, at the repeated times.
# NOTE(review): this cell's execution count (55) is lower than the setup cell's (89),
# so a saved output may come from an earlier kernel state. Re-run top to bottom.
phi_1_flipped_output = phi_1(X_flipped_spin, copy_time)
phi_1_flipped_output

tensor([[0.4866, 0.8786, 0.7124],
        [0.8444, 0.7118, 0.7421],
        [0.8912, 0.5958, 0.5714],
        [0.4868, 0.8751, 0.7111],
        [0.8448, 0.7087, 0.7407],
        [0.8917, 0.5931, 0.5703]], device='cuda:0',
       grad_fn=<SoftplusBackward0>)

In [56]:
# Display the flip estimator's output computed in the setup cell (In [89]).
# NOTE(review): the saved output has shape (6, 1) for batch_size=2 and 3 spins,
# presumably one value per (configuration, spin) pair. Confirm against the estimator.
# This cell's execution count (56) also predates the setup cell; re-run in order.
flip_estimate_

tensor([[0.4801],
        [0.7153],
        [0.5512],
        [0.4804],
        [0.7122],
        [0.5502]], device='cuda:0', grad_fn=<ReshapeAliasBackward0>)

In [67]:
# FIXME: `modified_states_tensor` is not defined in any cell of this notebook. It leaked
# from a deleted or unsaved cell, so this line raises NameError on a fresh kernel.
# Define it explicitly above, or delete this cell.
modified_states_tensor[:,1,:]

tensor([[-1.,  1.,  1.],
        [-1.,  1.,  1.]], device='cuda:0')

In [103]:
def flip_and_copy_spins(X_spins):
    """Repeat each configuration once per spin and build its single-spin flips.

    Parameters
    ----------
    X_spins : torch.Tensor
        2-D tensor of shape ``(batch_size, number_of_spins)``. Presumably the entries
        are +/-1 spins (TODO confirm), since flipping is done by multiplying by -1.

    Returns
    -------
    X_copy : torch.Tensor
        Shape ``(batch_size * number_of_spins, number_of_spins)``. Row
        ``b * number_of_spins + i`` equals ``X_spins[b]``.
    X_flipped : torch.Tensor
        Same shape as ``X_copy``. Row ``b * number_of_spins + i`` equals
        ``X_spins[b]`` with entry ``i`` multiplied by -1. The dtype follows torch
        promotion of ``X_spins`` with the default float dtype of the mask.

    Raises
    ------
    ValueError
        If ``X_spins`` is not 2-D.
    """
    if X_spins.dim() != 2:
        raise ValueError(
            f"X_spins must be 2-D (batch_size, number_of_spins), got shape {tuple(X_spins.shape)}"
        )
    batch_size, number_of_spins = X_spins.shape
    # Row i of the mask has -1 only at column i, so it flips exactly spin i.
    # Allocate it directly on X_spins' device instead of building it on CPU and copying.
    flip_mask = torch.ones((number_of_spins, number_of_spins), device=X_spins.device).fill_diagonal_(-1.)
    # Tile the mask so it lines up with the repeat_interleave'd copies below.
    flip_mask = flip_mask.repeat((batch_size, 1))
    X_copy = X_spins.repeat_interleave(number_of_spins, dim=0)
    X_flipped = X_copy * flip_mask
    return X_copy, X_flipped

In [121]:
# Evaluate phi_1 on every single-spin-flipped state. For each original configuration,
# keep the output entry belonging to the spin that was flipped.
batch_size = X_spins.size(0)
number_of_spins = X_spins.size(1)
X_copy, X_flipped = flip_and_copy_spins(X_spins)
copy_time = torch.repeat_interleave(current_time, number_of_spins)
transition_rates_ = phi_1(X_flipped, copy_time)
# Reshape the model output (transition_rates_), not the not-yet-defined
# `transition_rates`, which raised NameError on a fresh kernel.
# NOTE(review): assumes phi_1 returns (batch_size * number_of_spins, number_of_spins),
# as the In [55] output suggests.
transition_rates = transition_rates_.reshape(batch_size, number_of_spins, number_of_spins)
# "bii->bi" keeps the diagonal: entry i of the output for the state with spin i flipped.
transition_rates = torch.einsum("bii->bi", transition_rates)

In [127]:
# Small 0/1 example vector. `torch.tensor` infers float32 from the float literals;
# the legacy `torch.Tensor(...)` constructor is discouraged.
binary_ = torch.tensor([0., 1., 1., 0.])
binary_

tensor([0., 1., 1., 0.])

In [128]:
# Elementwise logical NOT of `binary_`, cast back to float (0 -> 1, nonzero -> 0).
torch.logical_not(binary_).float()

tensor([1., 0., 0., 1.])