In [1]:
import numpy as np
from random import randint, shuffle, sample

import math

In [2]:
# Quick notes on this implementation:
#
#   There are four classes of interest:
#   -> Variable: This class is used to create a variable with a boolean value. Each variable object stores a truthy/falsy 
#      value and a unique ID that distinguishes it from other variable objects. For large SAT expressions, creating an
#      object for a variable and its negation becomes very expensive. Instead one variable object can be used to represent
#      variables and their negations. This is described in more detail in the next section
#
#   -> Clause: This class takes in a list of (variable, negation flag) pairs and constructs a clause. Each clause object
#      stores a list of pairs of the form (pointer to variable object, 1 if variable is a negation else 0).
#
#   -> SAT: This class takes in a set of clauses and constructs a SAT instance. Each SAT object stores a list of its
#      variables sorted by id and the set of clause objects.
#
#   -> SAT_Factory: This class is used to generate random SAT instances and fixed SAT instances. A SAT factory takes in
#      the maximum number of variables and the maximum number of clauses to create, and optionally the corresponding
#      minimums (used when generating random instances).
#
# A variety of helper methods have been provided along with doc strings that explain their use. Feel free to add more if
# necessary. A file showcasing some of the methods is included in sat_generator_demo.ipynb.

In [3]:
class Variable:
    """
    A boolean variable identified by a unique integer id.

    One Variable object stands for both the variable and its negation; a Clause records
    negation separately as a flag stored next to the variable.
    """
    # Class-level counter: the id that the next constructed Variable will receive.
    id = 0

    def __init__(self, value):
        """
        Input(s):
            value: a value for the variable instance where value is meant to be a truthy/falsy value
                   -> None marks the variable as unset; None is falsy, so neg() returns True for it
        """
        self.id = Variable.id
        self.value = value
        Variable.id += 1

    def is_set(self):
        """
        Output(s):
            A boolean that indicates whether this variable has been assigned (its value is not None)
        """
        return self.value is not None

    def get_id(self):
        """
        Output(s):
            The id of this variable
        """
        return self.id

    def get_val(self):
        """
        Output(s):
            The value of this variable (may be None if unset)
        """
        return self.value

    @staticmethod
    def num_variables():
        """
        Output(s):
            The total number of variables created since the program started or the last reset_id() call
        """
        return Variable.id

    def set_val(self, val):
        """
        Input(s):
            val: the new value of this variable
        Output(s):
            None. Sets the value of this variable to val
        """
        self.value = val

    def neg(self):
        """
        Output(s):
            The logical negation of this variable's value (True when the value is None or falsy)
        """
        return not self.value

    @staticmethod
    def reset_id():
        """
        Output(s):
            None. Sets Variable.id to 0, so the next variable created receives id 0
        """
        Variable.id = 0

    def __str__(self, show_values=False, letter='x'):
        """
        Input(s):
            show_values: if True, include the current value, e.g. "(x3, True)"
                 letter: the prefix placed before the id
        Output(s):
            The string representation of this variable, e.g. "x3"
        """
        name = letter + str(self.id)
        if show_values:
            return "(" + name + ", " + str(self.value) + ")"
        return name

In [4]:
class Clause:
    """
    A disjunction of literals. Each literal is a (variable, negated) tuple where `negated`
    is truthy when the literal is the negation of the variable.
    """
    def __init__(self, variables=None):
        """
        Input(s):
            variables: an iterable of tuples of the form (ptr, b); defaults to an empty clause
                       -> ptr is a pointer to a variable object
                       -> b is a boolean (or 0/1) that is true if the variable is a negation
        """
        # Copy into a fresh list: avoids sharing one mutable default between instances, and
        # keeps a generator argument from being exhausted by the check below before storage.
        variables = [] if variables is None else list(variables)
        assert all(isinstance(pair, tuple) and isinstance(pair[0], Variable) for pair in variables)
        self.variables = variables

    @staticmethod
    def _literal_value(pair):
        """Return the (possibly non-bool) truth value of one (variable, negated) literal."""
        return pair[0].neg() if pair[1] else pair[0].get_val()

    def is_valid(self):
        """
        Output(s):
            A boolean that indicates if the clause is valid. A clause is valid if it has no
            -> contradictions (x1 or ~x1 or ...)
            -> repeated variables (x1 or x1 or ...)
            Both cases amount to the same variable id appearing more than once.
        """
        seen_ids = set()
        for pair in self.variables:
            var_id = pair[0].get_id()
            if var_id in seen_ids:
                return False
            seen_ids.add(var_id)
        return True

    def is_satisfied(self):
        """
        Output(s):
            A boolean that indicates whether at least one literal of this clause is true under the
            current assignment. An empty clause is never satisfied.
            -> An unset variable (value None) is falsy: its plain literal is false, but its negated
               literal is true, so a clause containing a negated unset variable counts as satisfied.
            NOTE(review): if unset variables should never satisfy a clause (partial-assignment
            semantics), each literal needs an is_set() check here; confirm the intended behavior.
        """
        return any(self._literal_value(pair) for pair in self.variables)

    def get_variables(self):
        """
        Output(s):
            A set containing all the variable objects in this clause instance
        """
        return {pair[0] for pair in self.variables}

    def __getitem__(self, idx):
        """
        Input(s):
            The index of the variable of interest
        Output(s):
            The tuple at the specified index of this clause instance
        """
        return self.variables[idx]

    def __len__(self):
        """
        Output(s):
            The number of literals in this clause instance
        """
        return len(self.variables)

    def __str__(self):
        """
        Output(s):
            The string representation of this clause, e.g. "{(x0, True) or (~x1, False)}"
        """
        parts = []
        for pair in self.variables:
            var_str = ("~" if pair[1] else "") + str(pair[0])
            val_str = str(bool(self._literal_value(pair)))
            parts.append("(" + var_str + ", " + val_str + ")")
        return "{" + " or ".join(parts) + "}"

In [5]:
class SAT:
    """
    A CNF formula: a set of Clause objects plus the distinct variables they mention, sorted by id.
    """
    def __init__(self, clauses=None):
        """
        Input(s):
            clauses: a set of clause objects; defaults to a new empty set
        """
        # A `set()` default would be one object shared (and stored) by every default-constructed instance.
        if clauses is None:
            clauses = set()
        assert isinstance(clauses, set)
        assert all(isinstance(clause, Clause) for clause in clauses)
        self.clauses = clauses
        unique_vars = set()
        for clause in clauses:
            unique_vars.update(clause.get_variables())
        # Sorted by id so that assignment[i] in try_assignment maps to the i-th smallest id.
        self.variables = sorted(unique_vars, key=lambda v: v.get_id())

    def try_assignment(self, assignment):
        """
        Input(s):
            assignment: a list of True/False values, one per variable in id order, used to set each variable
        Output(s):
            The number of clauses satisfied by this assignment. The variables keep the assigned values afterwards.
        """
        assert isinstance(assignment, list) and len(assignment) == len(self.variables)
        for variable, value in zip(self.variables, assignment):
            variable.set_val(value)
        return self.num_satisfied()

    def num_satisfied(self):
        """
        Output(s):
            The number of clauses satisfied with the current assignment
        """
        return sum(1 for clause in self.clauses if clause.is_satisfied())

    def is_satisfied(self):
        """
        Output(s):
            A boolean that indicates whether all clauses of this SAT instance are satisfied with the current assignment
        """
        return self.num_satisfied() == len(self.clauses)

    def get_variables(self):
        """
        Output(s):
            The list of all variables in this SAT instance, sorted by id (the list stored on the instance)
        """
        return self.variables

    def __len__(self):
        """
        Output(s):
            The number of clauses in this SAT instance
        """
        return len(self.clauses)

    def __str__(self):
        """
        Output(s):
            The string representation of this SAT instance, one clause per line
        """
        return "\n".join(str(clause) for clause in self.clauses)

In [5]:
class SAT_Factory:
    """
    Generates SAT instances built from freshly created Variable objects.
    """
    def __init__(self, max_num_variables, max_num_clauses, *args):
        """
        Input(s):
            max_num_variables: a cap on the number of variables created per instance (and so on clause length)
              max_num_clauses: a cap on the number of clauses in the SAT instance
                        *args: optional (min_num_variables, min_num_clauses) lower bounds used by
                               generate_random_instance; each defaults to 1 when omitted
        """
        assert max_num_clauses > 0 and max_num_variables > 0
        self.max_num_variables = max_num_variables
        self.max_num_clauses   = max_num_clauses
        # Always define the minimums so generate_random_instance works without the optional args.
        self.min_num_variables = args[0] if len(args) > 0 else 1
        self.min_num_clauses   = args[1] if len(args) > 1 else 1

    def set_num_variables(self, val):
        """
        Input(s):
            val: the new value of max_num_variables
        Output(s):
            None. Sets max_num_variables to val
        """
        assert val > 0
        self.max_num_variables = val

    def set_num_clauses(self, val):
        """
        Input(s):
            val: the new value of max_num_clauses
        Output(s):
            None. Sets max_num_clauses to val
        """
        assert val > 0
        self.max_num_clauses = val

    @staticmethod
    def _random_clause(variables, size):
        """Return a Clause over `size` distinct variables from `variables`, each negated with probability 1/2."""
        # Drawing distinct variables directly keeps every clause valid; rejection-sampling from the list
        # of all 2n literals needed on the order of 2^size retries for wide clauses.
        return Clause([(var, randint(0, 1)) for var in sample(variables, size)])

    def generate_random_instance(self):
        """
        Output(s):
             A SAT instance with a random number of variables (between min_num_variables and
             max_num_variables) and clauses (between min_num_clauses and max_num_clauses). Each clause
             has a size drawn uniformly from 1..number of variables, over distinct variables.
        """
        num_clauses = randint(self.min_num_clauses, self.max_num_clauses)
        num_variables = randint(self.min_num_variables, self.max_num_variables)
        variables = [Variable(None) for _ in range(num_variables)]
        clauses = set()
        for _ in range(num_clauses):
            clauses.add(self._random_clause(variables, randint(1, num_variables)))
        return SAT(clauses)

    def generate_fixed_instance(self, num_vars=None, num_clauses=None):
        """
        Input(s):
               num_vars: number of variables to create; defaults to max_num_variables
            num_clauses: number of clauses to create; defaults to max_num_clauses
        Output(s):
            A SAT instance with a fixed number of variables and clauses. Every clause contains each
            of the num_vars variables exactly once, in random order and with random polarity.
        """
        if num_vars is None:
            num_vars = self.max_num_variables
        if num_clauses is None:
            num_clauses = self.max_num_clauses
        variables = [Variable(None) for _ in range(num_vars)]
        clauses = set()
        for _ in range(num_clauses):
            clauses.add(self._random_clause(variables, num_vars))
        return SAT(clauses)