In [None]:
#@title
import random
import math
from z3 import *

class Clause:
    """A disjunction (OR) of literals over ``n`` Boolean variables, stored as two bitmasks.

    Bit ``i`` of ``vars_pos`` set means the literal ``x_i`` is present; bit ``i``
    of ``vars_neg`` set means the literal ``-x_i`` is present.
    """

    def __init__(self, mask_pos, mask_neg, n, allow_duplicates=False):
        """Build a clause from two masks in ``[0, 2^n - 1]``.

        Args:
            mask_pos: bitmask of positive literals.
            mask_neg: bitmask of negated literals.
            n: number of variables the clause ranges over (returned by ``len``).
            allow_duplicates: when False, a clause containing both ``x_i`` and
                ``-x_i`` is flagged as a tautology and its masks are normalised
                to all ones, so all tautologies over ``n`` variables compare equal.
        """
        self.len = n
        self.tautology = False
        self.allow_duplicates = allow_duplicates
        if not allow_duplicates and mask_pos & mask_neg:
            # x_i OR -x_i is always true.
            mask_pos = mask_neg = (1 << n) - 1
            self.tautology = True
        # Assigned after the tautology check so the normalised masks are kept.
        self.vars_pos = mask_pos
        self.vars_neg = mask_neg

    def __len__(self):
        """Number of variables the clause is defined over."""
        return self.len

    def __eq__(self, other):
        """Clauses are equal when masks and variable count match."""
        if not isinstance(other, Clause):
            return NotImplemented
        return (self.vars_pos, self.vars_neg, self.len) == (other.vars_pos, other.vars_neg, other.len)

    def __neg__(self):
        """Return the "opposite" clause: every x_i becomes -x_i and vice versa."""
        return Clause(self.vars_neg, self.vars_pos, self.len, allow_duplicates=self.allow_duplicates)

    def __repr__(self):
        """Render as ``"x0, -x1, ..."`` in variable order, or ``"True"`` for a tautology."""
        if self.tautology:
            return "True"
        literals = []
        for i in range(self.len):
            if (self.vars_pos >> i) & 1:
                literals.append("x" + str(i))
            if (self.vars_neg >> i) & 1:
                literals.append("-x" + str(i))
        return ", ".join(literals)

    def __call__(self, *arg):
        """Evaluate the clause on an assignment.

        Accepts ``clause(pos, neg)`` or ``clause((pos, neg))``: ``pos`` holds the
        bits of variables set to true, ``neg`` the bits of variables set to false.

        Raises:
            TypeError: when called with anything other than one or two arguments.
        """
        if len(arg) == 2:
            assignment_pos, assignment_neg = arg
        elif len(arg) == 1:
            assignment_pos, assignment_neg = arg[0][0], arg[0][1]
        else:
            raise TypeError(f"Clause expects (pos, neg) or ((pos, neg),), got {len(arg)} arguments")
        return bool(assignment_pos & self.vars_pos or assignment_neg & self.vars_neg)

    @staticmethod
    def random(n, k=2, allow_duplicates=False, bits=None):
        """Draw a random clause with ``k`` variables over ``n`` variables.

        Variables are sampled without replacement from ``range(n)`` or, if given,
        from ``bits``; each one is negated with probability 1/2.
        """
        chosen = random.sample(range(n) if bits is None else bits, k)
        mask_pos = 0
        mask_neg = 0
        for bit in chosen:
            # |= rather than += so a repeated bit in `bits` cannot carry into the next bit.
            if random.randint(0, 1) == 1:
                mask_pos |= 1 << bit
            else:
                mask_neg |= 1 << bit
        return Clause(mask_pos, mask_neg, n, allow_duplicates=allow_duplicates)

def check_dubp(clause, hash_table, clause_dict):
    """Register ``clause`` in ``hash_table`` unless an equal clause was already seen.

    The key is the sorted list of codes (looked up in ``clause_dict``) of the
    comma-separated literals of ``str(clause)``, so literal order is ignored.

    Returns:
        ``(clause, hash_table)`` for a new clause and ``(None, hash_table)`` for
        a duplicate. ``hash_table`` is updated in place and also returned.
    """
    codes = sorted(clause_dict[literal.replace(' ', '')] for literal in str(clause).split(","))
    key = str(codes)
    if key in hash_table:
        return None, hash_table
    hash_table[key] = ''
    return clause, hash_table

class Formula:
    """A CNF formula: a conjunction (AND) of clauses sharing one variable count."""

    def __init__(self, clauses):
        """Store ``clauses``; all of them must report the same ``len`` (variable count)."""
        # Materialise first: checking lengths on a generator used to exhaust it,
        # leaving the formula empty.
        clauses = list(clauses)
        assert len(set(len(clause) for clause in clauses)) <= 1
        self.clauses = clauses
        self.nof_vars = len(self.clauses[0]) if self.clauses else 0
        # Only used by the legacy ``__next__`` protocol.
        self.iter_idx = 0

    @staticmethod
    def random(n, k, allow_duplicates=False, t=2, n_used_vars=None):
        """Draw a random t-CNF formula over ``n`` variables with ``k`` clauses.

        If ``n_used_vars`` is given, all clauses draw their variables from one
        random subset of that many variables. Unless ``allow_duplicates`` is set,
        clauses that repeat an earlier one (ignoring literal order) are redrawn.
        NOTE: in that case the loop never ends if ``k`` exceeds the number of
        distinct clauses available.
        """
        clause_dict = {f'x{i}': i for i in range(n)}
        clause_dict.update({f'-x{i}': n + i for i in range(n)})

        bits = random.sample(range(n), n_used_vars) if n_used_vars is not None else None
        clauses = []
        hash_table = {}
        while len(clauses) < k:
            new_cl = Clause.random(n, t, allow_duplicates=allow_duplicates, bits=bits)
            if not allow_duplicates:
                new_cl, hash_table = check_dubp(new_cl, hash_table, clause_dict)
                if new_cl is None:
                    continue
            clauses.append(new_cl)
        return Formula(clauses)

    @staticmethod
    def random_assignment(n):
        """Draw a uniform assignment: a random positive mask and its n-bit complement."""
        pos = random.randint(0, (1 << n) - 1)
        neg = (1 << n) - 1 - pos
        return (pos, neg)

    @staticmethod
    def random_hashed(n):
        """Draw an assignment whose i-th bit is the F2 inner product <had, i>.

        ``had`` is uniform over every bit pattern wide enough to index ``0..n-1``
        (``(n - 1).bit_length()`` bits), so every parity test on the indices can
        occur. The previous ``randint(0, n)`` skipped patterns above ``n``.
        """
        width = max(n - 1, 0).bit_length()
        had = random.randint(0, (1 << width) - 1)

        pos = neg = 0
        for i in range(n):
            if bin(i & had).count('1') % 2:
                pos |= 1 << i
            else:
                neg |= 1 << i
        return (pos, neg)

    def __repr__(self):
        """Clauses joined by ``" and "``."""
        return " and ".join(str(cl) for cl in self.clauses)

    def __iter__(self):
        """Iterate over clauses; each call returns an independent iterator (safe to nest)."""
        return iter(self.clauses)

    def __len__(self):
        """Number of clauses."""
        return len(self.clauses)

    def __getitem__(self, i):
        """Return clause ``i`` (non-negative indices only)."""
        assert 0 <= i < len(self)
        return self.clauses[i]

    def __next__(self):
        """Legacy manual-iteration protocol driven by ``iter_idx``."""
        if self.iter_idx < len(self):
            res = self[self.iter_idx]
            self.iter_idx += 1
            return res
        raise StopIteration

    def __call__(self, *arg):
        """Evaluate with an assignment: True iff every clause is satisfied."""
        return all(clause(*arg) for clause in self.clauses)

    def approximate_sat(self, k=1000, avg=True, hashed=False):
        """Sample ``k`` assignments and measure the fraction of satisfied clauses.

        Returns the mean fraction if ``avg`` else the best fraction observed.
        Assignments come from ``random_hashed`` when ``hashed`` else uniformly.
        """
        max_cl = 0
        total = 0
        for _ in range(k):
            if hashed:
                ass = Formula.random_hashed(self.nof_vars)
            else:
                ass = Formula.random_assignment(self.nof_vars)
            sat = sum(1 for cl in self.clauses if cl(ass))
            total += sat
            max_cl = max(max_cl, sat)
        if avg:
            return total / (k * len(self))
        return max_cl / len(self)

    def approximate_count(self, k=1000, hashed=False):
        """Estimate the fraction of satisfying assignments from ``k`` random samples."""
        counter = 0
        for _ in range(k):
            if hashed:
                ass = Formula.random_hashed(self.nof_vars)
            else:
                ass = Formula.random_assignment(self.nof_vars)
            if self(ass):
                counter += 1
        return counter / k

    def brute_force(self, count=False):
        """Exhaustively search all 2^nof_vars assignments.

        Returns the first satisfying ``(pos, neg)`` found (or None), or the number
        of satisfying assignments when ``count`` is True. Assignments are visited
        in complementary pairs walking inwards from both ends of the range.
        """
        counter = 0
        start = 0
        end = (1 << self.nof_vars) - 1

        while start <= end:
            candidates = [(start, end)]
            # start == end only when nof_vars == 0; don't visit that assignment twice.
            if start != end:
                candidates.append((end, start))
            for pos, neg in candidates:
                if self(pos, neg):
                    if not count:
                        return (pos, neg)
                    counter += 1
            start += 1
            end -= 1

        return counter if count else None

In [None]:
def encode_to_token(r, req_len, mappings=None):
  """Tokenise a CNF record string into exactly ``req_len`` token ids.

  Recognised tokens: '(', ')', ':', 's', 'u', variables 'x<d>' / 'x<dd>' and
  negations 'Not(x<d>)' / 'Not(x<dd>)'; any other character (commas, spaces)
  is skipped. Shorter results are left-padded with the '#' id.

  Args:
    r: record string, e.g. "(x0, Not(x1)), (x2, x3):s".
    req_len: required output length.
    mappings: token -> id dict; defaults to the global ``token_mappings``.

  Raises:
    ValueError: if the record yields more than ``req_len`` tokens.
    KeyError: for a token missing from ``mappings``.
  """
  if mappings is None:
    mappings = token_mappings
  digits = set('0123456789')

  input_idx = []
  i = 0
  while i < len(r):
    c = r[i]
    if c in ['(', ')', ':', 's', 'u']:
      input_idx.append(mappings[c])
      i += 1
    elif c == 'x':
      # 'x' plus one or two digits; slicing avoids IndexError at end of string.
      width = 3 if r[i+2:i+3] in digits else 2
      input_idx.append(mappings[r[i:i+width]])
      i += width
    elif c == 'N':
      # 'Not(x3)' is 7 characters, 'Not(x12)' is 8.
      width = 8 if r[i+6:i+7] in digits else 7
      input_idx.append(mappings[r[i:i+width]])
      i += width
    else:
      i += 1

  if len(input_idx) > req_len:
    raise ValueError(f"record encodes to {len(input_idx)} tokens, more than req_len={req_len}")
  if len(input_idx) < req_len:
    input_idx = [mappings['#']] * (req_len - len(input_idx)) + input_idx

  return input_idx

def decode_to_CNFs(input_idx, rev_mappings=None):
  """Map token ids back to their string tokens, joined by single spaces.

  Args:
    input_idx: iterable of token ids.
    rev_mappings: id -> token dict; defaults to the global ``rev_token_mappings``.
  """
  if rev_mappings is None:
    rev_mappings = rev_token_mappings
  return " ".join(rev_mappings[t] for t in input_idx)


In [None]:
import sympy 
def check_dubp_formula(fomula, hash_table, clause_dict, allow_duplicates=False):
  """Record ``fomula`` in ``hash_table`` unless an equivalent formula was seen before.

  ``str(fomula)`` is split into clauses on 'and' and into literals on ',';
  each literal is mapped to its code in ``clause_dict``. With
  ``allow_duplicates=False`` the key ignores literal order inside a clause and
  clause order inside the formula; with ``allow_duplicates=True`` both orders
  are significant.

  Returns:
    ``(fomula, hash_table)`` if the formula is new (its key is added), otherwise
    ``(None, hash_table)``. ``hash_table`` is mutated in place.
  """
  def enc(x):
    return clause_dict[x.replace(' ', '')]

  formula_key = []
  for cl in str(fomula).split('and'):
    codes = [enc(lit) for lit in cl.split(',')]
    if not allow_duplicates:
      codes.sort(reverse=True)
    # Keep codes as a tuple: concatenating their digits into one int collided
    # once codes reached two digits (e.g. [1, 12] and [11, 2] both gave 112).
    formula_key.append(tuple(codes))

  if not allow_duplicates:
    formula_key.sort()

  formula_str = str(formula_key)
  if formula_str in hash_table:
    return None, hash_table
  hash_table[formula_str] = ''
  return fomula, hash_table




In [None]:
from sympy import to_cnf, Or, Not, And, simplify
from z3 import *

# Convert Z3 to SymPy
def z3_to_sympy(z3_expr):
    """Convert a Z3 Boolean formula (variables, Not, And, Or, True/False) to SymPy.

    This cell runs ``from z3 import *`` after ``from sympy import ...``, so the
    bare names ``And``/``Or``/``Not``/``simplify`` here are Z3's. SymPy
    constructors are therefore reached through the ``sympy`` module. (Z3's
    ``is_bool`` is true for every Boolean expression, so it cannot be used to
    detect leaves.)

    Raises:
        ValueError: for any Z3 operator other than those listed above.
    """
    import sympy

    if is_true(z3_expr):
        return sympy.true
    if is_false(z3_expr):
        return sympy.false
    if is_not(z3_expr):
        return sympy.Not(z3_to_sympy(z3_expr.arg(0)))
    if is_and(z3_expr):
        return sympy.And(*[z3_to_sympy(c) for c in z3_expr.children()])
    if is_or(z3_expr):
        return sympy.Or(*[z3_to_sympy(c) for c in z3_expr.children()])
    if is_const(z3_expr) and is_bool(z3_expr):
        # A Boolean constant other than True/False is a variable such as x0.
        return sympy.Symbol(str(z3_expr))
    raise ValueError(f"unsupported Z3 expression: {z3_expr}")

def create_random_sat(n_vars, n_cls, clause_dict, hash_table=None, sim=False, allow_duplicates=False, n_used_vars=None, verbose=True):
    """Draw a previously unseen random CNF formula, solve it with Z3 and return a labelled record.

    Args:
        n_vars: number of Boolean variables ``x0 .. x{n_vars-1}``.
        n_cls: number of clauses.
        clause_dict: literal -> code map passed to ``check_dubp_formula``.
        hash_table: keys of formulas already produced; formulas found there are
            redrawn. A fresh dict is created when omitted (a ``{}`` default
            would be shared between calls).
        sim: if True, run Z3's ``simplify`` on the formula before solving.
        allow_duplicates: forwarded to ``Formula.random`` and ``check_dubp_formula``.
        n_used_vars: forwarded to ``Formula.random``.
        verbose: print the Z3 formula.

    Returns:
        ``(label, record, hash_table)``. ``label`` is ``"s"`` when Z3 answers sat
        and ``"u"`` otherwise. ``record`` is the solver's printed assertion with
        the ``And(``/``Or`` wrappers stripped, followed by ``:label``.
    """
    if hash_table is None:
        hash_table = {}

    F = None
    while F is None:
        F = Formula.random(n_vars, n_cls, allow_duplicates=allow_duplicates, n_used_vars=n_used_vars)
        F, hash_table = check_dubp_formula(F, hash_table, clause_dict, allow_duplicates=allow_duplicates)

    s = Solver()
    variables = [Bool(f"x{i}") for i in range(n_vars)]

    # Build And(Or(lit, ...), ...) from each clause's string form, e.g. "x0, -x3".
    cls = []
    for clause in F.clauses:
        lits = []
        for v in str(clause).split(','):
            lit = variables[int(v.split('x')[1])]
            if '-' in v:
                lit = Not(lit)
            lits.append(lit)
        cls.append(Or(*lits))
    z3_expr = And(*cls)

    if verbose:
        print(z3_expr)

    if sim:
        # The former SymPy round-trip never left Z3: the star import shadowed
        # SymPy's names, so only Z3's simplify ran. Call it directly.
        z3_expr = simplify(z3_expr)

    s.add(z3_expr)

    # str(s) is "[And(...)]"; the split on "\n     " rejoins the continuation
    # lines of Z3's printer (assumes that 5-space indent), then drop the brackets.
    record = "".join(str(s).split("\n     "))[1:-1]

    # NOTE(review): a Z3 "unknown" result is labelled "u" as well.
    label = "s" if str(s.check()) == 'sat' else "u"

    # Drop And's closing paren, then strip the "And(" / "Or" wrappers.
    record = record[:-1].replace('And(', '').replace('Or', '') + f':{label}'

    return label, record, hash_table


## Formulas with 5 variables

In [None]:
nvars = 5  # number of Boolean variables (x0..x4) used by the cells below

In [None]:
# Literal -> integer code used for duplicate detection: x_i -> i, -x_i -> nvars + i.
clause_dict = {
    **{f'x{i}': i for i in range(nvars)},
    **{f'-x{i}': nvars + i for i in range(nvars)},
}

In [None]:
# Token vocabulary: x_i -> i, Not(x_i) -> nvars + i, then '(', ')', ':', 's', 'u',
# and finally '#' (padding, i.e. "don't care").
token_mappings = {
    **{f'x{i}': i for i in range(nvars)},
    **{f'Not(x{i})': nvars + i for i in range(nvars)},
    **{symbol: 2 * nvars + offset for offset, symbol in enumerate('():su#')},
}
print(token_mappings)

rev_token_mappings = {token_id: token for token, token_id in token_mappings.items()}

In [None]:
# A random 2-literal clause over 3 variables, then the tautology x1 OR -x1.
C = Clause.random(3)
print(C)
C = Clause(2, 2, 3)
print(C)

In [None]:
# 10 distinct clauses over 5 variables, all drawn from the same 3 variables;
# show the formula-level duplicate check against an empty table.
F = Formula.random(5, 10, n_used_vars=3)
check_dubp_formula(F, hash_table={}, clause_dict=clause_dict)

### Generate Training Dataset

In [None]:
import numpy as np
from tqdm.auto import trange

# Examples kept per class (SAT and UNSAT); the output file is named "1M".
# (This was the unparsable literal `1kjjj000`.)
n_ones = 1_000_000
# 10 clauses x 4 tokens ('(', lit, lit, ')') + ':' + label
seq_len = 42

n_sat = 0
n_unsat = 0
hash_table = {}  # formulas generated so far, shared across the loop for deduplication
rows = []        # collected rows; np.concatenate per sample would be quadratic

pbar = trange(20000000)
for _ in pbar:
    sat, line, hash_table = create_random_sat(nvars, 10, clause_dict, sim=False, hash_table=hash_table,
                                              allow_duplicates=True, verbose=False)

    # Keep the sample only while its class is below quota. Every kept sample,
    # including the first, is counted.
    if sat == 's':
        if n_sat == n_ones:
            continue
        n_sat += 1
    else:
        if n_unsat == n_ones:
            continue
        n_unsat += 1

    rows.append(encode_to_token(line, seq_len))

    if n_sat == n_ones and n_unsat == n_ones:
        break

    pbar.set_description(f"n_sat: {n_sat}, n_unsat: {n_unsat}")

cnf_tokens = np.array(rows)
print(n_sat, n_unsat)
print(len(hash_table))

In [None]:
from pathlib import Path

out_path = Path("data/cnf_tokens_1M.npy")
out_path.parent.mkdir(parents=True, exist_ok=True)  # np.save does not create missing directories
np.save(out_path, cnf_tokens)

### Generate Analysis Dataset

In [None]:
import numpy as np
from tqdm.auto import trange

# Examples kept per class (SAT and UNSAT); the output file is named "100K".
n_ones = 100000
# 10 clauses x 4 tokens ('(', lit, lit, ')') + ':' + label
seq_len = 42

n_sat = 0
n_unsat = 0
# NOTE(review): a fresh table, so formulas may overlap the training set; reuse the
# training cell's table if the analysis set must be disjoint.
hash_table = {}
rows = []  # collected rows; np.concatenate per sample would be quadratic

pbar = trange(20000000)
for _ in pbar:
    sat, line, hash_table = create_random_sat(nvars, 10, clause_dict, sim=False, hash_table=hash_table,
                                              allow_duplicates=True, verbose=False)

    # Keep the sample only while its class is below quota. Every kept sample,
    # including the first, is counted.
    if sat == 's':
        if n_sat == n_ones:
            continue
        n_sat += 1
    else:
        if n_unsat == n_ones:
            continue
        n_unsat += 1

    rows.append(encode_to_token(line, seq_len))

    if n_sat == n_ones and n_unsat == n_ones:
        break

    pbar.set_description(f"n_sat: {n_sat}, n_unsat: {n_unsat}")

cnf_tokens = np.array(rows)
print(n_sat, n_unsat)
print(len(hash_table))

In [None]:
from pathlib import Path

out_path = Path("data/cnf_tokens_100K.npy")
out_path.parent.mkdir(parents=True, exist_ok=True)  # np.save does not create missing directories
np.save(out_path, cnf_tokens)