In [16]:
import tensorflow as tf
from tensorflow.contrib import predictor

# Directory of the exported SavedModel, relative to the notebook's working directory.
export_dir = "model_sr50"
# Build the prediction callable once; the `suggest` methods further down call it
# with a feed dict keyed by "input" and read 'policy_probabilities' /
# 'sat_probabilities' from the result.
predict_fn = predictor.from_saved_model(export_dir)

INFO:tensorflow:Restoring parameters from model_sr50/variables/variables


In [17]:
from matplotlib import pyplot as plt
import numpy as np
import random

In [18]:
import sys 
# Put the directory two levels up at the front of the module search path so the
# project-local imports in the next cell (cnf_dataset, dpll, cnf) resolve.
# NOTE(review): this is relative to the current working directory — confirm the
# notebook is launched from its own folder.
sys.path.insert(0,'../..')

In [19]:
from cnf_dataset import clauses_to_matrix
from dpll import DPLL, RandomClauseDPLL, MostCommonVarDPLL, RandomVarDPLL
from cnf import get_random_kcnf, CNF, get_sats_SR, get_pos_SR
from tqdm import tqdm
from collections import Counter

In [20]:
import math
from collections import defaultdict

def jw(clauses):
    """Pick a branching literal by summing per-clause weights.

    Every literal collects a weight of ``2 ** -len(clause)`` from each clause
    it occurs in, so occurrences in short clauses count more than occurrences
    in long ones.

    Args:
        clauses: iterable of clauses, each a sized iterable of hashable
            literals.

    Returns:
        The literal with the largest accumulated weight. On ties the literal
        that was encountered first wins.

    Raises:
        ValueError: if no literal is present at all (``max`` of an empty
            mapping).
    """
    weights = defaultdict(int)

    for clause in clauses:
        # The contribution depends only on the clause length, so compute it once.
        contribution = 2. ** (-len(clause))
        for literal in clause:
            weights[literal] += contribution

    return max(weights, key=weights.get)


In [21]:
# Upper bound on the number of `run` calls a single solver instance may make;
# the `run` methods defined below return None as soon as the counter exceeds it.
LIMIT_RUNS = 1000000

In [22]:
def shorten_cnf(cnf: CNF):
    """Repeatedly apply the two forced-assignment rules until neither fires.

    Rule 1 (unit clause): if some clause has exactly one literal, that literal
    is set via ``cnf.set_var``.
    Rule 2 (pure literal): otherwise, the first variable in ``cnf.vars`` that
    occurs with only one polarity in the clauses is set to that polarity.

    After each assignment the search restarts from rule 1 on the formula
    returned by ``set_var``.

    Args:
        cnf: formula exposing ``clauses``, ``vars`` and ``set_var(literal)``.
            Presumably ``set_var`` returns a new simplified formula — confirm
            against the ``CNF`` class.

    Returns:
        The formula reached when no unit clause and no single-polarity
        variable remains.
    """
    while True:
        for clause in cnf.clauses:
            if len(clause) == 1:
                cnf = cnf.set_var(clause[0])
                break
        else:
            # No unit clause left: look for a variable seen in one polarity only.
            seen = {literal for clause in cnf.clauses for literal in clause}
            for var in cnf.vars:
                occurs_pos = var in seen
                occurs_neg = (-var) in seen
                if occurs_pos and not occurs_neg:
                    cnf = cnf.set_var(var)
                    break
                if occurs_neg and not occurs_pos:
                    cnf = cnf.set_var(-var)
                    break
            else:
                # Neither rule applied; nothing more to simplify.
                return cnf

def make_normalized(cls):
    """Build a subclass of ``cls`` whose ``run`` simplifies before branching.

    The generated class keeps ``cls``'s ``suggest`` heuristic but replaces
    ``run`` with a search that first calls ``shorten_cnf`` on every formula
    it visits.

    Args:
        cls: solver class providing ``suggest(cnf)`` and the counters
            ``number_of_runs`` and ``number_of_errors`` on its instances.

    Returns:
        A new class named ``"Normalized" + cls.__name__``.
    """
    class NormalizedDPLL(cls):
        def run(self, cnf: CNF):
            """Search for a satisfying sequence of branching literals.

            Returns the list of literals chosen at branching points (literals
            fixed inside ``shorten_cnf`` are not included), or None when no
            solution was found or the ``LIMIT_RUNS`` budget was exhausted.
            """
            assert isinstance(cnf, CNF)
            self.number_of_runs += 1
            if self.number_of_runs > LIMIT_RUNS:
                return None

            simplified = shorten_cnf(cnf)
            if simplified.is_true():
                return []
            if simplified.is_false():
                return None

            literal = self.suggest(simplified)

            # Try the suggested polarity first.
            first_branch = self.run(simplified.set_var(literal))
            if first_branch is not None:
                return [literal] + first_branch

            # The suggestion failed; fall back to the opposite polarity.
            second_branch = self.run(simplified.set_var(-literal))
            if second_branch is None:
                return None
            # The heuristic picked the wrong polarity here, so count it as an error.
            self.number_of_errors += 1
            return [-literal] + second_branch

    NormalizedDPLL.__name__ = "Normalized{}".format(cls.__name__)
    return NormalizedDPLL

In [23]:
# Display NumPy floats with 3 digits of precision and without scientific notation.
np.set_printoptions(precision=3, suppress=True)

In [24]:
import tensorflow as tf
import os

# How many copies of the encoded formula are stacked into one model input;
# the `suggest` methods below only read row 0 of the prediction.
BATCH_SIZE = 1

In [25]:

class GraphBasedDPLL(DPLL):
    """Solver whose branching literal comes from the loaded model's policy output."""

    def suggest(self, input_cnf: CNF):
        """Return the literal with the highest policy probability.

        The formula is encoded with ``clauses_to_matrix``, repeated
        ``BATCH_SIZE`` times and fed to ``predict_fn``. For each variable
        ``v`` the entry ``[v - 1][0]`` is read for the positive literal and
        ``[v - 1][1]`` for the negative one.

        Returns:
            The best-scoring literal, or None if no probability is
            strictly greater than 0.0.
        """
        num_clauses = len(input_cnf.clauses)
        num_vars = max(input_cnf.vars)
        encoded = clauses_to_matrix(input_cnf.clauses, num_clauses, num_vars)
        batch = np.asarray([encoded] * BATCH_SIZE)

        prediction = predict_fn({"input": batch})
        policy_probs = prediction['policy_probabilities']

        top_prob, top_literal = 0.0, None
        for var in input_cnf.vars:
            for literal in (var, -var):
                column = 0 if literal > 0 else 1
                prob = policy_probs[0][var - 1][column]
                # Strict comparison: the first literal reaching a value wins ties.
                if prob > top_prob:
                    top_prob, top_literal = prob, literal
        return top_literal

class MostCommonDPLL(DPLL):
    """Solver that branches on the literal occurring most often in the clauses."""

    def suggest(self, cnf: CNF):
        """Return the most frequent literal over all clauses of ``cnf``.

        Raises:
            IndexError: if the formula contains no literal at all.
        """
        occurrences = Counter(
            literal
            for clause in cnf.clauses
            for literal in clause
        )
        (literal, _count), = occurrences.most_common(1)[:1] or [occurrences.most_common(1)[0]]
        return literal
    
class JeroslawDPLL(DPLL):
    """Solver that branches on the literal chosen by the ``jw`` scoring rule."""

    def suggest(self, cnf: CNF):
        """Delegate the choice of branching literal to ``jw`` on the clause list."""
        clauses = cnf.clauses
        return jw(clauses)
    

class HumbleDPLL(DPLL):
    """Solver that trusts the model's policy unless its SAT estimate is low."""

    def suggest(self, input_cnf: CNF):
        """Return a branching literal from the model, or from ``jw`` as fallback.

        The model is always queried. Its policy pick (highest probability
        strictly above 0.0) is used unless the first entry of
        'sat_probabilities' is below 0.3, in which case the ``jw`` choice
        replaces it.
        """
        num_clauses = len(input_cnf.clauses)
        num_vars = max(input_cnf.vars)
        encoded = clauses_to_matrix(input_cnf.clauses, num_clauses, num_vars)
        batch = np.asarray([encoded] * BATCH_SIZE)

        prediction = predict_fn({"input": batch})
        policy_probs = prediction['policy_probabilities']
        sat_prob = prediction['sat_probabilities'][0]

        top_prob, top_literal = 0.0, None
        for var in input_cnf.vars:
            for literal in (var, -var):
                column = 0 if literal > 0 else 1
                prob = policy_probs[0][var - 1][column]
                if prob > top_prob:
                    top_prob, top_literal = prob, literal

        if sat_prob < 0.3:
            # Low SAT estimate: discard the policy pick and use JW instead.
            return jw(input_cnf.clauses)
        return top_literal

In [26]:
# Wrap each heuristic solver so that its `run` applies `shorten_cnf` before
# every branching step (see make_normalized).
NormalizedGraphBasedDPLL = make_normalized(GraphBasedDPLL)
NormalizedMostCommonDPLL = make_normalized(MostCommonDPLL)
NormalizedJeroslawDPLL = make_normalized(JeroslawDPLL)
NormalizedHumbleDPLL = make_normalized(HumbleDPLL)

In [27]:
class FastHumbleDPLL(DPLL):
    """Model-guided solver that permanently switches a subtree to ``jw``.

    Once ``suggest`` reports that the heuristic was used at a node, both
    child searches of that node receive the flag and skip the model call.
    """

    def run(self, cnf: CNF, fast=False):
        """Search for a satisfying sequence of branching literals.

        Args:
            cnf: formula to solve.
            fast: when truthy, ``suggest`` skips the model and uses ``jw``.

        Returns:
            The list of literals chosen at branching points (literals fixed
            inside ``shorten_cnf`` are not included), or None when no
            solution was found or the ``LIMIT_RUNS`` budget was exhausted.
        """
        assert isinstance(cnf, CNF)
        self.number_of_runs += 1
        if self.number_of_runs > LIMIT_RUNS:
            return None

        simplified = shorten_cnf(cnf)
        if simplified.is_true():
            return []
        if simplified.is_false():
            return None

        use_heuristic_below, literal = self.suggest(simplified, fast)

        # Try the suggested polarity first.
        first_branch = self.run(simplified.set_var(literal), use_heuristic_below)
        if first_branch is not None:
            return [literal] + first_branch

        # The suggestion failed; fall back to the opposite polarity.
        second_branch = self.run(simplified.set_var(-literal), use_heuristic_below)
        if second_branch is None:
            return None
        # The suggested polarity was wrong here, so count it as an error.
        self.number_of_errors += 1
        return [-literal] + second_branch

    def suggest(self, input_cnf: CNF, fast):
        """Pick a branching literal and report whether ``jw`` was used.

        Args:
            input_cnf: formula to branch on.
            fast: when truthy, the model is not queried at all.

        Returns:
            A pair ``(flag, literal)``. ``flag`` is ``fast`` itself when it
            is truthy, otherwise the result of ``sat_prob < 0.3``;
            ``literal`` is the ``jw`` choice when the flag is truthy and the
            model's policy pick otherwise.
        """
        if fast:
            # Already in heuristic mode: no model call.
            return fast, jw(input_cnf.clauses)

        num_clauses = len(input_cnf.clauses)
        num_vars = max(input_cnf.vars)
        encoded = clauses_to_matrix(input_cnf.clauses, num_clauses, num_vars)
        batch = np.asarray([encoded] * BATCH_SIZE)

        prediction = predict_fn({"input": batch})
        policy_probs = prediction['policy_probabilities']
        sat_prob = prediction['sat_probabilities'][0]

        top_prob, top_literal = 0.0, None
        for var in input_cnf.vars:
            for candidate in (var, -var):
                column = 0 if candidate > 0 else 1
                prob = policy_probs[0][var - 1][column]
                if prob > top_prob:
                    top_prob, top_literal = prob, candidate

        low_sat_estimate = sat_prob < 0.3
        if low_sat_estimate:
            # Overwriting with JW
            top_literal = jw(input_cnf.clauses)
        return low_sat_estimate, top_literal

In [28]:
def compute_steps(sats, dpll_cls):
    """Run a fresh solver on every formula and collect its counters.

    Args:
        sats: sequence of formulas to solve.
        dpll_cls: zero-argument solver factory; instances must offer
            ``run(formula)``, ``number_of_runs`` and ``number_of_errors``.

    Returns:
        ``(steps, errors)``: two parallel lists with one entry per formula
        for which ``run`` returned something other than None.
    """
    step_counts, error_counts = [], []
    solved = 0
    for formula in tqdm(sats):
        solver = dpll_cls()
        if solver.run(formula) is None:
            # Unsolved (or budget exhausted): contributes to neither list.
            continue
        step_counts.append(solver.number_of_runs)
        error_counts.append(solver.number_of_errors)
        solved += 1
    print("Within {} steps solved {} problems out of {}".format(LIMIT_RUNS, solved, len(sats)))
    return step_counts, error_counts

In [29]:
def compute_and_print_steps(sats, dpll_cls):
    """Benchmark ``dpll_cls`` on ``sats`` and print summary statistics.

    Prints the class name, then (via ``compute_steps``) the solved count,
    then the mean and standard deviation of the step and error counts, and
    finally the raw list of step counts.
    """
    print(dpll_cls.__name__)
    steps, errors = compute_steps(sats, dpll_cls)

    summary = "#Sats: {}; avg step: {:.2f}; stdev step: {:.2f}; avg error: {:.2f}; stdev error: {:.2f}"
    stats = (
        len(steps),
        np.mean(steps),
        np.std(steps),
        np.mean(errors),
        np.std(errors),
    )
    print(summary.format(*stats))
    print("Table: ", steps)

In [30]:
def print_all(s, n, m, light=False):
    """Generate ``s`` formulas and benchmark two solvers on them.

    Both random generators are reseeded with 1 before generation. Each
    formula comes from ``get_pos_SR(M, M, N)``; the same list is then passed
    to ``FastHumbleDPLL`` and ``NormalizedJeroslawDPLL``.

    Args:
        s: number of formulas to generate (also stored in the global ``S``).
        n: stored in the global ``N``; passed as the third argument of
            ``get_pos_SR``.
        m: stored in the global ``M``; passed as the first two arguments of
            ``get_pos_SR``.
        light: accepted for call-site compatibility; not read by this body.

    Raises:
        AssertionError: if fewer than ``s`` formulas were produced, which
            happens when ``s`` exceeds the internal cap of 100000 attempts.
    """
    global S, N, M
    # N: number of clauses, M: number of variables
    S, N, M = s, n, m

    MAX_TRIES = 100000

    random.seed(1)
    np.random.seed(1)

    sats = []
    attempts = 0
    while attempts < MAX_TRIES and len(sats) < S:
        sats.append(get_pos_SR(M, M, N))
        attempts += 1

    assert len(sats) == S
    print("We have generated {} formulas".format(len(sats)))
    for solver_cls in (FastHumbleDPLL, NormalizedJeroslawDPLL):
        compute_and_print_steps(sats, solver_cls)

In [31]:
# Small run: 10 formulas with n=1000 and m=10 (see print_all for how n and m are used).
print_all(10, 1000, 10)

  0%|          | 0/10 [00:00<?, ?it/s]

We have generated 10 formulas
FastHumbleDPLL


100%|██████████| 10/10 [00:27<00:00,  1.33s/it]
100%|██████████| 10/10 [00:00<00:00, 670.23it/s]

Within 1000000 steps solved 10 problems out of 10
#Sats: 10; avg step: 5.30; stdev step: 1.10; avg error: 0.00; stdev error: 0.00
Table:  [7, 5, 4, 7, 4, 6, 6, 4, 5, 5]
NormalizedJeroslawDPLL
Within 1000000 steps solved 10 problems out of 10
#Sats: 10; avg step: 6.00; stdev step: 1.48; avg error: 0.30; stdev error: 0.46
Table:  [8, 6, 9, 4, 5, 6, 7, 5, 5, 5]





In [None]:
# 100 formulas with n=200, m=10. This cell has no recorded output in the saved notebook.
print_all(100, 200, 10)

In [None]:
# 100 formulas with n=200, m=20. This cell has no recorded output in the saved notebook.
print_all(100, 200, 20)

In [None]:
# 100 formulas with n=500, m=30. This cell has no recorded output in the saved notebook.
print_all(100, 500, 30)

In [None]:
# 100 formulas with n=10000, m=40. `light=True` is accepted by print_all but not read in its body.
print_all(100, 10000, 40, light=True)

In [None]:
# 100 formulas with n=10000, m=50. `light=True` is accepted by print_all but not read in its body.
print_all(100, 10000, 50, light=True)

In [None]:
# 100 formulas with n=10000, m=70. `light=True` is accepted by print_all but not read in its body.
print_all(100, 10000, 70, light=True)

In [None]:
# 100 formulas with n=10000, m=90. `light=True` is accepted by print_all but not read in its body.
print_all(100, 10000, 90, light=True)

In [None]:
# 100 formulas with n=10000, m=110. `light=True` is accepted by print_all but not read in its body.
print_all(100, 10000, 110, light=True)

In [None]:
# 100 formulas with n=10000, m=130. `light=True` is accepted by print_all but not read in its body.
print_all(100, 10000, 130, light=True)