In [1]:
import os
import sys
import torch
from pathlib import Path

In [2]:
# DIRECTORIES
from graph_bridges import base_path
#results folder
from graph_bridges import results_path
#data folder 
from graph_bridges import data_path

In [3]:
# CONFIGURATIONS IMPORT
from graph_bridges.configs.graphs.config_sb import SBConfig,SBTrainerConfig
from graph_bridges.configs.graphs.config_sb import SteinSpinEstimatorConfig
from graph_bridges.configs.graphs.config_sb import ParametrizedSamplerConfig
from graph_bridges.configs.graphs.config_sb import get_sb_config_from_file

# DATA CONFIGS
from graph_bridges.data.graph_dataloaders_config import CommunitySmallConfig,EgoConfig
# BACKWARD RATES CONFIGS 
from graph_bridges.models.backward_rates.backward_rate_config import BackRateMLPConfig,GaussianTargetRateImageX0PredEMAConfig

In [4]:
# MODEL IMPORTS
from graph_bridges.models.generative_models.sb import SB
from graph_bridges.models.trainers.sb_training import SBTrainer

All the functionality of a Schrödinger bridge is encapsulated in the `SB` class:

`from graph_bridges.models.generative_models.sb import SB`

The Schrödinger bridge problem essentially requires three elements for model building:

- $P_0$, provided by `sb.data_dataloader`;
- $P_1$, provided by `sb.target_dataloader`;
- the reference process, which defines a measure $Q$, provided by `sb.reference_process`.

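It requires two models (the architectures to be trained):

- `sb` — **TODO:** this list names only one entry; spell out both models (an `SB` instance exposes at least `past_model`, which is inspected later in this notebook).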
    
Following the architecture of the Hugging Face `diffusers` library, it also requires a scheduler and a pipeline for inference (generation).

In [5]:
#===========================================
# MODEL SET UP
#===========================================
# Results are stored in results_path/experiment_name/experiment_type/experiment_indentifier.
# If experiment_indentifier is set to None, a time number (UNIX time) is used instead.
# NOTE: `experiment_indentifier` (sic) is the keyword spelling SBConfig expects.
# NOTE(review): delete=True presumably clears an earlier run stored under the same
# identifier — confirm before reusing an identifier you want to keep.
sb_config = SBConfig(delete=True,
                     experiment_name="graph",
                     experiment_type="sb",
                     experiment_indentifier="tutorial_sb_trainer_stein_epsilon_02")

# Data, backward-rate model, Stein estimator and sampler sub-configurations.
sb_config.data = EgoConfig(as_image=False, batch_size=5, full_adjacency=False)
sb_config.model = BackRateMLPConfig(time_embed_dim=14, hidden_layer=200)
sb_config.stein = SteinSpinEstimatorConfig(stein_sample_size=200,
                                          stein_epsilon=0.2)
sb_config.sampler = ParametrizedSamplerConfig(num_steps=5)

# Training configuration. The device falls back to the CPU when no CUDA device is
# available, so the notebook still runs end-to-end on CPU-only machines.
sb_config.trainer = SBTrainerConfig(learning_rate=1e-3,
                                    num_epochs=500,
                                    save_metric_epochs=50,
                                    save_model_epochs=50,
                                    save_image_epochs=50,
                                    device="cuda:0" if torch.cuda.is_available() else "cpu",
                                    metrics=["graphs_plots",
                                             "graphs",
                                             "histograms"])

In [6]:
#========================================
# TRAIN
#========================================
# Build the trainer from the configuration assembled in the set-up cell and run
# training. How often metrics, checkpoints and plots are saved is controlled by
# the save_* fields of sb_config.trainer.
sb_trainer = SBTrainer(sb_config)
sb_trainer.train_schrodinger()

cuda:0


5it [00:01,  3.86it/s]
5it [00:00, 716.26it/s]
5it [00:00, 1002.94it/s]
5it [00:00, 716.04it/s]
5it [00:00, 986.11it/s]
5it [00:00, 770.50it/s]
5it [00:00, 516.76it/s]
5it [00:00, 661.98it/s]
5it [00:00, 830.88it/s]
5it [00:00, 705.26it/s]
5it [00:00, 829.77it/s]
5it [00:00, 714.39it/s]
5it [00:00, 478.89it/s]
5it [00:00, 368.74it/s]
5it [00:00, 806.81it/s]
5it [00:00, 647.45it/s]
5it [00:00, 834.92it/s]
5it [00:00, 705.78it/s]
5it [00:00, 834.65it/s]
5it [00:00, 704.29it/s]
5it [00:00, 825.65it/s]
5it [00:00, 623.15it/s]
5it [00:00, 834.36it/s]
5it [00:00, 665.89it/s]
5it [00:00, 832.67it/s]
5it [00:00, 617.37it/s]
5it [00:00, 825.36it/s]
5it [00:00, 674.80it/s]
5it [00:00, 834.82it/s]
5it [00:00, 618.17it/s]
5it [00:00, 834.06it/s]
5it [00:00, 617.97it/s]
5it [00:00, 822.35it/s]
5it [00:00, 620.06it/s]
5it [00:00, 824.19it/s]
5it [00:00, 621.47it/s]
5it [00:00, 716.14it/s]
5it [00:00, 626.67it/s]
5it [00:00, 716.26it/s]
5it [00:00, 626.78it/s]
5it [00:00, 835.52it/s]
5it [00:00, 716.

41 40
Time computing degree mmd:  0:00:06.039037
{'degree': 1.3922195295820237}
Time computing clustering mmd:  0:00:08.650237
{'degree': 1.3922195295820237, 'cluster': 1.2813378351761215}
[WinError 193] %1 is not a valid Win32 application
[WinError 193] %1 is not a valid Win32 application
[WinError 193] %1 is not a valid Win32 application
[WinError 193] %1 is not a valid Win32 application
[WinError 193] %1 is not a valid Win32 application
[WinError 193] %1 is not a valid Win32 application
[WinError 193] %1 is not a valid Win32 application
[WinError 193] %1 is not a valid Win32 application
[WinError 193] %1 is not a valid Win32 application
[WinError 193] %1 is not a valid Win32 application
[WinError 193] %1 is not a valid Win32 application
[WinError 193] %1 is not a valid Win32 application
[WinError 193] %1 is not a valid Win32 application
[WinError 193] %1 is not a valid Win32 application
[WinError 193] %1 is not a valid Win32 application
[WinError 193] %1 is not a valid Win32 applica

5it [00:00, 288.10it/s]
5it [00:00, 357.98it/s]
5it [00:00, 385.67it/s]
5it [00:00, 358.12it/s]
  0%|                                                                                               | 0/500 [00:00<?, ?it/s]
5it [00:00, 624.60it/s]

5it [00:00, 712.93it/s]

5it [00:00, 716.34it/s]

5it [00:00, 357.31it/s]

5it [00:00, 332.88it/s]

5it [00:00, 717.37it/s]

5it [00:00, 832.40it/s]

5it [00:00, 712.76it/s]

5it [00:00, 712.88it/s]

5it [00:00, 454.36it/s]

5it [00:00, 835.69it/s]

5it [00:00, 831.11it/s]

5it [00:00, 454.45it/s]

5it [00:00, 716.31it/s]

5it [00:00, 716.24it/s]

5it [00:00, 835.75it/s]

5it [00:00, 416.63it/s]

5it [00:00, 716.51it/s]

5it [00:00, 708.47it/s]

5it [00:00, 776.09it/s]

5it [00:00, 716.07it/s]

5it [00:00, 718.65it/s]

5it [00:00, 712.88it/s]

5it [00:00, 715.85it/s]

5it [00:00, 713.49it/s]

5it [00:00, 714.09it/s]

5it [00:00, 712.90it/s]

5it [00:00, 712.86it/s]

5it [00:00, 716.22it/s]

5it [00:00, 716.00it/s]

5it [00:00, 701.06it/s]

5it 

Epoch: 1, Loss: -0.8188105380740387



5it [00:00, 716.00it/s]

5it [00:00, 713.51it/s]

5it [00:00, 386.50it/s]

5it [00:00, 719.29it/s]

5it [00:00, 773.12it/s]

5it [00:00, 703.41it/s]

5it [00:00, 662.08it/s]

5it [00:00, 713.10it/s]

5it [00:00, 626.48it/s]

5it [00:00, 662.80it/s]

5it [00:00, 554.80it/s]

5it [00:00, 624.43it/s]

5it [00:00, 627.40it/s]

5it [00:00, 713.54it/s]

5it [00:00, 624.08it/s]

5it [00:00, 712.76it/s]

5it [00:00, 624.13it/s]

5it [00:00, 416.66it/s]

5it [00:00, 416.60it/s]

5it [00:00, 713.63it/s]

5it [00:00, 626.65it/s]

5it [00:00, 626.86it/s]

5it [00:00, 712.90it/s]

5it [00:00, 712.35it/s]

5it [00:00, 708.88it/s]

5it [00:00, 664.29it/s]

5it [00:00, 717.37it/s]

5it [00:00, 629.36it/s]

5it [00:00, 712.44it/s]

5it [00:00, 712.61it/s]

5it [00:00, 623.84it/s]

5it [00:00, 713.78it/s]

5it [00:00, 712.90it/s]

5it [00:00, 712.76it/s]

5it [00:00, 712.95it/s]

5it [00:00, 716.46it/s]

5it [00:00, 624.15it/s]

5it [00:00, 659.01it/s]

5it [00:00, 716.41it/s]

5it [00:00, 715.53it/s]


In [None]:
# Reload the bridge trained above from its results folder. The identifier must match
# the experiment_indentifier passed to SBConfig in the model set-up cell; otherwise a
# different (or missing) run is loaded.
# NOTE(review): sinkhorn_iteration_to_load=0 presumably selects the checkpoint of the
# first Sinkhorn iteration — confirm.
sb_trained = SB()
sb_trained.load_from_results_folder(experiment_name="graph",
                                    experiment_type="sb",
                                    experiment_indentifier="tutorial_sb_trainer_stein_epsilon_02",
                                    sinkhorn_iteration_to_load=0)

In [None]:
# Display the `past_model` restored by load_from_results_folder.
# NOTE(review): presumably the backward-rate model of the loaded Sinkhorn iteration — confirm.
sb_trained.past_model

In [None]:
# Draw 32 samples from the loaded bridge on the CPU.
# NOTE(review): the leading positional arguments (None, 0) are passed straight to the
# pipeline; their meaning is not visible here — confirm against SB.pipeline.
# return_path=False presumably returns only the final state, not the full trajectory.
x_end = sb_trained.pipeline(None, 0, torch.device("cpu"), sample_size=32, return_path=False)

In [None]:
# The generated samples are in spin representation; display their value range.
# (Spins presumably take values in {-1, +1} — confirm from the printed bounds.)
(x_end.min(), x_end.max())

In [None]:
# Convert the spin samples back to graph form using the data dataloader's transform.
# NOTE(review): the result is presumably a batch of adjacency matrices — confirm.
x_adj = sb_trained.data_dataloader.transform_to_graph(x_end)

In [None]:
# Shape of the raw spin samples (leading dimension presumably equals sample_size=32).
x_end.shape

In [None]:
from