In [1]:
import os
import sys
import torch

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

%matplotlib inline

In [2]:
from graph_bridges.configs.graphs.config_sb import SBConfig
from graph_bridges.configs.graphs.config_sb import TrainerConfig, ParametrizedSamplerConfig
from graph_bridges.data.graph_dataloaders_config import EgoConfig,CommunityConfig,CommunitySmallConfig
from graph_bridges.models.backward_rates.backward_rate_config import BackRateMLPConfig
from graph_bridges.configs.graphs.config_sb import SBConfig, ParametrizedSamplerConfig, SteinSpinEstimatorConfig
from graph_bridges.models.backward_rates.backward_rate_config import GaussianTargetRateImageX0PredEMAConfig

In [3]:
from graph_bridges.models.generative_models.sb import SB
from graph_bridges.models.metrics.sb_metrics import paths_marginal_histograms

In [4]:
# Experiment configuration for the SB model on graph data, then model construction.
# NOTE(review): `delete=True` presumably removes a previous experiment folder for this
# config — confirm in SBConfig before re-running on results you want to keep.
SEED = 42  # fix RNGs so Restart & Run All reproduces the sampled paths/histograms
torch.manual_seed(SEED)
np.random.seed(SEED)

config = SBConfig(delete=True)

# Dataset in non-image layout (as_image=False), batches of 32.
# Alternatives used in earlier runs (same keyword arguments): CommunityConfig, CommunitySmallConfig.
config.data = EgoConfig(as_image=False, batch_size=32, full_adjacency=False)

# Backward-rate model. Alternative tried earlier:
#   BackRateMLPConfig(time_embed_dim=14, hidden_layer=150)
config.model = GaussianTargetRateImageX0PredEMAConfig(time_embed_dim=12, fix_logistic=False)

config.stein = SteinSpinEstimatorConfig(stein_sample_size=100)
# num_steps=10 matches the "10it" progress bars printed by the sampling cells below.
config.sampler = ParametrizedSamplerConfig(num_steps=10, step_type="Poisson")
config.optimizer = TrainerConfig(
    learning_rate=1e-3,
    num_epochs=200,
    save_metric_epochs=20,
    metrics=["graphs_plots", "histograms"],
)
config.align_configurations()

# Build the SB object on CPU; later cells use its data_dataloader, pipeline and training_model.
device = torch.device("cpu")
sb = SB(config, device)

In [5]:
# Grab the first training batch to use as a fixed starting point for conditioned sampling.
# NOTE(review): which batch comes first presumably depends on dataloader shuffling (seeded above
# only if the loader uses the global torch RNG) — confirm.
databatch = next(iter(sb.data_dataloader.train()))
x_spins_data = databatch[0]  # element 0 of the batch is used as the spin data

In [8]:
# Sample two full trajectories with the SB pipeline. return_path=True keeps the whole path,
# return_path_shape=True returns it in path shape, and each call also yields its time grid.
# NOTE(review): the 2nd positional argument is presumably the sinkhorn iteration — confirm
# against SB.pipeline's signature.

# Path 1: no model, iteration 0, no explicit start (the label says it comes from the dataloader).
print("From Dataloader full path in image shape with times")
spins_path_1, times_1 = sb.pipeline(
    None,
    0,
    device,
    return_path=True,
    return_path_shape=True,
)

# Path 2: current training model, iteration 1, started from the batch taken above.
print("From given start")
spins_path_2, times_2 = sb.pipeline(
    sb.training_model,
    1,
    device,
    x_spins_data,
    return_path=True,
    return_path_shape=True,
)

From Dataloader full path in image shape with times


10it [00:00, 1431.41it/s]


From given start


10it [00:00, 80.73it/s]


In [7]:
# Marginal histograms along backward/forward paths for the current training model
# (sinkhorn iteration 0, no past model).
# The returned time grids get their own names: rebinding `times_1`/`times_2` here would clobber
# the pipeline cell's values, making `times_2[0]` below depend on cell execution order.
(
    backward_histogram,
    forward_histogram,
    histogram_times_1,
    histogram_times_2,
) = paths_marginal_histograms(
    sb=sb,
    sinkhorn_iteration=0,
    device=device,
    current_model=sb.training_model,
    past_to_train_model=None,
)

10it [00:00, 926.59it/s]
10it [00:00, 72.33it/s]
10it [00:00, 990.41it/s]
10it [00:00, 76.68it/s]
10it [00:00, 900.32it/s]
10it [00:00, 73.66it/s]
10it [00:00, 1089.51it/s]
10it [00:00, 73.65it/s]
10it [00:00, 945.54it/s]
10it [00:00, 81.66it/s]

Past Model 0
tensor([99., 78., 82., 45., 29., 26., 16., 12.,  5.,  8.,  6.,  4.,  3.,  4.,
         3.,  0.,  0., 72., 72., 53., 28., 24., 17., 13.,  5.,  3.,  3.,  3.,
         2.,  4.,  1.,  0.,  0., 75., 39., 22., 22., 16.,  9.,  3.,  4.,  5.,
         4.,  2.,  3.,  1.,  0.,  0., 49., 31., 22., 10., 13.,  4.,  1.,  4.,
         2.,  3.,  3.,  1.,  0.,  0., 27., 20., 12., 11.,  5.,  2.,  3.,  2.,
         1.,  2.,  2.,  0.,  0., 17., 12., 11.,  4.,  5.,  4.,  3.,  0.,  2.,
         3.,  0.,  0., 13., 12.,  6.,  4.,  4.,  1.,  2.,  2.,  1.,  0.,  0.,
        13.,  8.,  3.,  6.,  3.,  0.,  2.,  1.,  0.,  0., 10.,  6.,  5.,  0.,
         0.,  2.,  2.,  0.,  0.,  1.,  4.,  2.,  2.,  2.,  1.,  0.,  0.,  5.,
         3.,  1.,  3.,  1.,  0.,  0.,  2.,  1.,  3.,  3.,  0.,  0.,  1.,  2.,
         1.,  0.,  0.,  3.,  1.,  0.,  0.,  3.,  0.,  0.,  0.,  0.,  0.])
Past Model 1
tensor([83., 86., 86., 80., 89., 82., 70., 85., 92., 82., 88., 73., 75., 85.,
        73., 77., 72., 81., 91., 85., 80.,




From given start


10it [00:00, 62.73it/s]


In [64]:
torch.sum(x_spins_data)  # total spin of the starting batch, compared against the path start below

tensor(-4356.)

In [65]:
# Sanity check: slice 0 of the path started from x_spins_data should reproduce that batch,
# so the spin sums must match (presumably dim 1 of spins_path_2 is time — confirm).
path_start_sum = spins_path_2[:, 0, :].sum()
assert path_start_sum.item() == x_spins_data.sum().item(), "path does not start at x_spins_data"
path_start_sum

tensor(-4356.)

In [45]:
# Time grid of the path sampled from the given start, as displayed below: decreasing from 1.0
# to 0.0 in steps of 0.11, with a final 0.01 step (11 points).
# Relies on `times_2` from the `sb.pipeline(...)` cell — make sure no cell run after it rebinds `times_2`.
times_2[0]

tensor([1.0000, 0.8900, 0.7800, 0.6700, 0.5600, 0.4500, 0.3400, 0.2300, 0.1200,
        0.0100, 0.0000], dtype=torch.float64)