In [None]:
# Upper bound (inclusive) on the variable count swept by the 3-SAT experiment grid.
MAX_VARIABLE_NUM = 10
# Directory containing the trained model's meta graph and checkpoints.
# `$HOME` is left unexpanded here; the restore cell runs it through os.path.expandvars.
META_DIR = "$HOME/deepsat/models/neuropol-18-09-18-002956"

In [None]:
import tensorflow as tf
import numpy as np
import os
import sys
sys.path.append('..')
import notebook_tools

In [None]:

# Restore the trained model: import the freshest meta graph found under META_DIR
# into the default graph, then load its saved weights into a new session.
sess = tf.Session()
meta_file = notebook_tools.get_most_fresh_meta(os.path.expandvars(META_DIR))
saver = tf.train.import_meta_graph(meta_file, clear_devices=True)
# Initialize every global variable first so that anything the saver does not
# cover still has a value; restore() below overwrites the saved ones.
sess.run(tf.global_variables_initializer())

checkpoint_dir = os.path.dirname(meta_file)
checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if checkpoint is None:
    # latest_checkpoint returns None when the directory has no checkpoint state;
    # report the offending path instead of saver.restore's generic None error.
    raise FileNotFoundError("No checkpoint found in {}".format(checkpoint_dir))
saver.restore(sess, checkpoint)

# Graph holding the restored tensors; later cells look up placeholders and
# outputs in it by name to build feed dicts for new data.
graph = tf.get_default_graph()


In [None]:
# Suffix selecting which `policy_prob_*` / `sat_prob_*` output tensors to read.
LEVEL = 20
# Number of copies of a single formula stacked into one network input batch.
BATCH_SIZE = 1

# Handles into the restored graph, looked up by tensor name.
g_inputs = graph.get_tensor_by_name("inputs:0")
g_policy_probs, g_sat_probs = (
    graph.get_tensor_by_name('policy_prob_{}:0'.format(LEVEL)),
    # NOTE(review): g_sat_probs is not used by any visible cell.
    graph.get_tensor_by_name('sat_prob_{}:0'.format(LEVEL)),
)

from dpll import DPLL
from cnf import CNF
from cnf_dataset import clauses_to_matrix

class GraphBasedDPLL(DPLL):
    """DPLL variant whose branching literal is chosen by the restored network.

    The formula is encoded with `clauses_to_matrix`, fed through the graph's
    `policy_prob_{LEVEL}` output, and the literal with the highest predicted
    probability is suggested.
    """

    def suggest(self, input_cnf: CNF):
        """Return the literal the policy network rates highest for `input_cnf`.

        Positive literals `var` read column 0 of the policy output for that
        variable, negated literals `-var` read column 1.  Variables are assumed
        to be 1-based ints (row `var - 1` is indexed) — TODO confirm against CNF.

        Returns None when the formula has no variables (nothing to branch on).
        """
        if not input_cnf.vars:
            # max() below would raise ValueError on an empty collection.
            return None

        clause_num = len(input_cnf.clauses)
        var_num = max(input_cnf.vars)
        matrix = clauses_to_matrix(input_cnf.clauses, clause_num, var_num)
        # The graph consumes a batch; stack BATCH_SIZE copies of the one formula.
        inputs = np.asarray([matrix] * BATCH_SIZE)

        policy_probs = sess.run(g_policy_probs, feed_dict={g_inputs: inputs})
        # All batch entries are identical copies, so only the first is read.
        return self._best_literal(policy_probs[0], input_cnf.vars)

    @staticmethod
    def _best_literal(var_probs, variables):
        """Return the signed literal with the highest probability in `var_probs`.

        Starts from -inf (not 0.0) so a literal is still returned when every
        probability is 0; ties keep the first literal seen, and variables are
        scanned in sorted order so tie-breaking does not depend on set order.
        """
        best_prob = float("-inf")
        best_svar = None
        for var in sorted(variables):
            for svar in (var, -var):
                svar_prob = var_probs[var - 1][0 if svar > 0 else 1]
                if svar_prob > best_prob:
                    best_prob = svar_prob
                    best_svar = svar
        return best_svar


In [None]:
# Experiment grid handed to notebook_tools.execute_experiments. Each entry is a
# 4-tuple; the third field is the variable count (see loop below). The meaning
# of the other fields is defined by notebook_tools — TODO confirm there.
experiments = [
    (100, 2, 2, 3),
    (100, 2, 4, 20),
]
# For each variable count 4..MAX_VARIABLE_NUM, add two cases whose last field is
# 5x and then 10x the variable count.
for var_num in range(4, MAX_VARIABLE_NUM + 1):
    experiments.extend((100, 3, var_num, var_num * factor) for factor in (5, 10))

solver_classes = [GraphBasedDPLL]
all_stats = notebook_tools.execute_experiments(experiments, solver_classes)

In [None]:
# Display the raw stats object returned by notebook_tools.execute_experiments above.
all_stats

In [None]:
# Summarize the collected experiment stats; presumably renders a table and plots,
# per the helper's name — see notebook_tools for the exact output.
notebook_tools.summary_table_and_plots(all_stats)