In [1]:
import os
os.chdir('..')

# ATTRACTIVE_FLAG=True

In [2]:
import torch
from torch import autograd
import pickle
import wandb
import random

from nn_models import lbp_message_passing_network, GIN_Network_withEdgeFeatures
from ising_model.pytorch_dataset import build_factorgraph_from_SpinGlassModel
from ising_model.spin_glass_model import SpinGlassModel
from factor_graph import FactorGraphData
from factor_graph import DataLoader_custom as DataLoader_pytorchGeometric

from ising_model.pytorch_geometric_data import spinGlass_to_torchGeometric


import os
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import parameters
from parameters import ROOT_DIR, alpha, alpha2, SHARE_WEIGHTS, BETHE_MLP, NUM_MLPS
import cProfile

MODEL_MAP_FLAG = False
CLASSIFICATION_FLAG = True
TRAINING_FLAG = False
BETHE_MLP = False
ATTRACTIVE_FIELD = True
LEARNING_RATE = 1e-6
LR_DECAY_FLAG = False
ALPHA = alpha
ALPHA2 = alpha2
MSG_PASSING_ITERS = 50
SHARE_WEIGHTS = False


MODE = "train" #run "test" or "train" mode

TEST_TRAINED_MODEL = True 
EXPERIMENT_NAME = 'trained_MAP_attrField_10layer_2MLPs_noFinalBetheMLP/' #used for saving results when MODE='test'

USE_WANDB = False
####### Training PARAMETERS #######
MAX_FACTOR_STATE_DIMENSIONS = 2
EPSILON = 0 #set factor states with potential 0 to EPSILON for numerical stability
SHARE_WEIGHTS = True if not TRAINING_FLAG  else SHARE_WEIGHTS
MODEL_NAME = "MAP_spinGlass_%dlayer_alpha=%f.pth" % (MSG_PASSING_ITERS, parameters.alpha)
TRAINED_MODELS_DIR = ROOT_DIR + "trained_models_map/" #trained models are stored here

##########################################################################################################
N_MIN_TRAIN = 3
N_MAX_TRAIN = 3
F_MAX_TRAIN = .1
C_MAX_TRAIN = 5.0
ATTRACTIVE_FIELD_TRAIN = ATTRACTIVE_FIELD

N_MIN_VAL = 3
N_MAX_VAL = 3
F_MAX_VAL = .1
C_MAX_VAL = 5.0
ATTRACTIVE_FIELD_VAL = ATTRACTIVE_FIELD

REGENERATE_DATA = False
DATA_DIR = "./data/spin_glass_map/"

TRAINING_DATA_SIZE = 50
VAL_DATA_SIZE = 50#100
TEST_DATA_SIZE = 200

TRAIN_BATCH_SIZE=50
VAL_BATCH_SIZE=50

EPOCH_COUNT = 5000 if TRAINING_FLAG else 5
PRINT_FREQUENCY = 10 if TRAINING_FLAG else 1
VAL_FREQUENCY = 10 if TRAINING_FLAG else 1
SAVE_FREQUENCY = 100 if TRAINING_FLAG else 1

TEST_DATSET = 'val' 

##### Optimizer parameters #####
STEP_SIZE=(EPOCH_COUNT//4)
LR_DECAY=.5
if ATTRACTIVE_FIELD_TRAIN == True:
        LEARNING_RATE = LEARNING_RATE
else:
    LEARNING_RATE = LEARNING_RATE

def get_dataset(dataset_type):
    '''
    Store/load a list of SpinGlassModels
    When using, convert to BPNN or GNN form with either
    build_factorgraph_from_SpinGlassModel(pytorch_geometric=True) for BPNN or spinGlass_to_torchGeometric() for GNN
    '''
    assert(dataset_type in ['train', 'val', 'test'])
    if dataset_type == 'train':
        datasize = TRAINING_DATA_SIZE
        ATTRACTIVE_FIELD = ATTRACTIVE_FIELD_TRAIN
        N_MIN = N_MIN_TRAIN
        N_MAX = N_MAX_TRAIN
        F_MAX = F_MAX_TRAIN
        C_MAX = C_MAX_TRAIN
    elif dataset_type == 'val':
        datasize = VAL_DATA_SIZE
        ATTRACTIVE_FIELD = ATTRACTIVE_FIELD_VAL
        N_MIN = N_MIN_VAL
        N_MAX = N_MAX_VAL
        F_MAX = F_MAX_VAL
        C_MAX = C_MAX_VAL
    else:
        datasize = TEST_DATA_SIZE
        ATTRACTIVE_FIELD = ATTRACTIVE_FIELD_TEST

    dataset_file = DATA_DIR + dataset_type + '%d_%d_%d_%.2f_%.2f_attField=%s.pkl' % (datasize, N_MIN, N_MAX, F_MAX, C_MAX, ATTRACTIVE_FIELD)
    if REGENERATE_DATA or (not os.path.exists(dataset_file)):
        print("REGENERATING DATA!!")
        spin_glass_models_list = [SpinGlassModel(N=random.randint(N_MIN, N_MAX),\
                                                f=np.random.uniform(low=0, high=F_MAX),\
                                                c=np.random.uniform(low=0, high=C_MAX),\
                                                attractive_field=ATTRACTIVE_FIELD) for i in range(datasize)]
        if not os.path.exists(DATA_DIR):
            os.makedirs(DATA_DIR)
        with open(dataset_file, 'wb') as f:
            pickle.dump(spin_glass_models_list, f)
    else:
        with open(dataset_file, 'rb') as f:
            spin_glass_models_list = pickle.load(f)
    return spin_glass_models_list


spin_glass_models_list_train = get_dataset(dataset_type='train')

In [28]:
# Generate Factor Beliefs
from ising_model.libdai_utils import *
def bp_factor_marginals(sg_model, maxiter=None, updates="SEQRND", damping=None, map_flag=False):
    if maxiter is None:
        maxiter=LIBDAI_LBP_ITERS
    N = sg_model.lcl_fld_params.shape[0]

    # Set some constants
    maxiter = maxiter
    tol = 1e-9
    verb = 1
    # Store the constants in a PropertySet object
    opts = dai.PropertySet()
    opts["maxiter"] = str(maxiter)   # Maximum number of iterations
    opts["tol"] = str(tol)           # Tolerance for convergence
    opts["verbose"] = str(verb)      # Verbosity (amount of output generated)bpopts["updates"] = "SEQRND"
    opts["updates"] = updates
    opts["logdomain"] = "1"
    if damping is not None:
        opts["damping"] = str(damping)
    opts['inference'] = ('MAXPROD' if map_flag else 'SUMPROD')

    sg_FactorGraph = build_libdaiFactorGraph_from_SpinGlassModel(sg_model, fixed_variables={})
    bp = dai.BP( sg_FactorGraph, opts )
    bp.init()
    bp.run()

    return np.array([
        [bp.beliefF(i)[j] for j in range(4)]
        for i in range(N*N, sg_FactorGraph.nrFactors())
    ])

def exact_factor_marginals(sg_model, verbose=False, map_flag=True, classification_flag=True):
    # Set some constants
    maxiter = 10000
    tol = 1e-9
    verb = 0
    # Store the constants in a PropertySet object
    opts = dai.PropertySet()
    opts["maxiter"] = str(maxiter)   # Maximum number of iterations
    opts["tol"] = str(tol)           # Tolerance for convergence
    opts["verbose"] = str(verb)      # Verbosity (amount of output generated)bpopts["updates"] = "SEQRND"
    opts["updates"] = "HUGIN"
    opts["inference"] = "SUMPROD"

    N = sg_model.lcl_fld_params.shape[0]
    log_marginals = np.zeros([N*(N-1)*2, 4])

    sg_FactorGraph = build_libdaiFactorGraph_from_SpinGlassModel(sg_model, fixed_variables={})
    jt = dai.JTree( sg_FactorGraph, opts )
    jt.init()
    jt.run()
    logZ = jt.logZ()
    for rol in range(N):
        for col in range(N-1):
            vi = rol*N+col
            for si, states in enumerate([[-1,-1],[-1,1],[1,-1],[1,1]]):
                s1, s2 = states
                sg_FactorGraph = build_libdaiFactorGraph_from_SpinGlassModel(sg_model, fixed_variables={vi:s1, vi+N:s2})
                jt = dai.JTree( sg_FactorGraph, opts )
                jt.init()
                jt.run()
                log_marginals[vi,si] = jt.logZ()
    for col in range(N):
        for rol in range(N-1):
            vi = rol*N+col
            for si, states in enumerate([[-1,-1],[-1,1],[1,-1],[1,1]]):
                s1, s2 = states
                sg_FactorGraph = build_libdaiFactorGraph_from_SpinGlassModel(sg_model, fixed_variables={vi:s1, vi+1:s2})
                jt = dai.JTree( sg_FactorGraph, opts )
                jt.init()
                jt.run()
                log_marginals[vi+N*(N-1),si] = jt.logZ()
    probability = np.exp(log_marginals-logZ)
    return probability

In [29]:
exact_beliefs = [exact_factor_marginals(sg) for sg in spin_glass_models_list_train]
bp_beliefs = [bp_factor_marginals(sg, damping=ALPHA) for sg in spin_glass_models_list_train]
print(exact_beliefs[0])
print(bp_beliefs[0])

[[0.25422042 0.25660698 0.2405871  0.24858549]
 [0.25242154 0.23753386 0.25228563 0.25775896]
 [0.00193236 0.00193236 0.00193236 0.00193236]
 [0.25966197 0.23514555 0.23200535 0.27318713]
 [0.25673077 0.24797641 0.24895803 0.24633479]
 [0.00193236 0.00193236 0.00193236 0.00193236]
 [0.26211883 0.24870858 0.22783657 0.26133602]
 [0.23900813 0.25094727 0.2482073  0.2618373 ]
 [0.24110007 0.24611536 0.25370746 0.25907712]
 [0.25512227 0.23968525 0.24958491 0.25560757]
 [0.26693995 0.23776723 0.23572481 0.25956802]
 [0.24716978 0.25549497 0.24449754 0.25283771]]
[[0.25421518 0.24059234 0.25661223 0.24858025]
 [0.26211819 0.22783721 0.24870922 0.26133538]
 [0.25241895 0.25228824 0.23753645 0.25775636]
 [0.23898718 0.24822825 0.25096822 0.26181635]
 [0.26415741 0.23850734 0.22305801 0.27427723]
 [0.25966134 0.23200598 0.23514618 0.2731865 ]
 [0.25511893 0.24958826 0.23968859 0.25560422]
 [0.25670846 0.24898034 0.24799872 0.24631247]
 [0.26693779 0.23572696 0.23776939 0.25956585]
 [0.25668013