# SAT Solving in Python
---
**Neal Ó Riain**

# ``` $ whoami```
---

<img src="img/me.jpg" width="35%" align="right"> 
 
 * Former Astrophysicist (🔭, 🚀, 🌝)
 
 
<br> 
 
 
 * Current Data Scientist at Spotify.

# Outline

* What is the SAT problem?
  
* How can we solve it?

* Examples

# Before we begin

Slides: [n-o-r.xyz/sat-slides/](http://www.n-o-r.xyz/sat-slides/)

Notebook: [github.com/neal-o-r/sat-slides/](http://www.github.com/neal-o-r/sat-slides/)

<center>
<H1> What is the SAT problem? </H1>
</center>

# What is SAT?

The SAT problem asks the question: can we find a setting of the variables that satisfies a formula of propositional logic?

Hold on, what's a *formula of propositional logic*?

# Propositional logic

Take some Boolean variables, $x_i \in \{{\rm True}, {\rm False}\}$

```python
x_i = True
x_i = False
```

We can choose to negate a variable, $\neg x_i$.

```python
not x_i
```

And we can combine variables using $\lor$ (or) and $\land$ (and)

```python
x_1 or x_2
x_1 and x_2
# or
any([x_1, x_2])
all([x_1, x_2])
```



# Propositional logic

Using these rules we can make formulas,

$$
(a \land \neg b) \lor (c \land (\neg c \land b))
$$

or, in Python,

```python
(a and (not b)) or (c and ((not c) and b))
```

The SAT question is: can we find a setting for the variables $a$, $b$ and $c$ such that the whole formula evaluates to ${\rm True}$?
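For a formula this small, one way to answer the question is simply to try all $2^3 = 8$ settings. A minimal sketch using only the standard library; `formula` is just this slide's expression wrapped in a function:

```python
from itertools import product

def formula(a: bool, b: bool, c: bool) -> bool:
    return (a and (not b)) or (c and ((not c) and b))

# try every combination of True/False for (a, b, c) and keep the ones that work
satisfying = [abc for abc in product([True, False], repeat=3) if formula(*abc)]
print(satisfying)  # [(True, False, True), (True, False, False)]
```

Only settings with $a = {\rm True}$ and $b = {\rm False}$ work: the second half of the formula contains $c \land \neg c$, so it can never be ${\rm True}$.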

<center>
<H1> So what, who cares? </H1>
</center>

# SAT Problem

This is a very old problem (dating back to the Greeks), and unsurprisingly it's pretty hard to solve.

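The problem is in NP, which means a proposed solution can be checked in polynomial time; but no polynomial-time algorithm for *finding* a solution is known, and in the worst case solving a formula with $n$ variables takes on the order of $2^n$ steps

So in the worst case we don't know how to do fundamentally better than guessing...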

# SAT Problem

That said, we have good solvers that handle practical problems with around $10^6$ variables

The problem is NP-complete, which means any other problem in NP can be translated (in polynomial time) into a SAT problem

If you want to solve knapsack, TSP, graph colouring, etc., SAT solvers are a good place to start

# Practical problems

SAT solvers are used for a *lot* of very practical applications

* Planning (what's the best route to take given constraints?)
* Scheduling (how should we partition a resource between these options?)
* Building (where should these components go on a circuit?)

Any situation where you need to satisfy complex constraints, and make a set of binary decisions

Extensions of the algorithm I'm about to show also power SMT solvers (like Z3), which are used for formal software verification

<center>
<H1> How can we solve this problem? </H1>
</center>

# CNF

To start, it turns out that every formula can be written in *Conjunctive Normal Form* (CNF), which is a fancy way of saying an "AND of ORs"

Given a formula:
$$
(a \land \neg b) \lor (c \land (\neg c \land b)) \dots
$$

By applying the rules of logic (De Morgan's laws, removing double negations, and distributing $\lor$ over $\land$), we can rewrite it in a form like

$$
(a \lor b \lor \dots ) \land \\
(b \lor \neg d \lor \dots ) \land \\
(\neg c \lor a \lor \dots ) \land \\
\dots
$$
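To make "applying the rules of logic" concrete for the earlier example: the inner part is $c \land \neg c \land b$, and distributing $\lor$ over $\land$ gives one clause per pair of terms. The CNF below was derived by hand for this sketch, so the code checks it against the original formula on every assignment:

```python
from itertools import product

def original(a: bool, b: bool, c: bool) -> bool:
    return (a and (not b)) or (c and ((not c) and b))

def cnf(a: bool, b: bool, c: bool) -> bool:
    # distribute OR over AND: one clause for each pair (x, y) with
    # x from (a, ¬b) and y from (c, ¬c, b)
    return ((a or c) and (a or not c) and (a or b)
            and (not b or c) and (not b or not c) and (not b or b))

# compare the two forms on all 2**3 assignments
equivalent = all(original(*abc) == cnf(*abc)
                 for abc in product([True, False], repeat=3))
print(equivalent)  # True
```

The clause $(\neg b \lor b)$ is always ${\rm True}$ and could be dropped. In general, this naive distribution can make a formula exponentially larger.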

# Definitions
$$
(a \lor b \lor \neg c) \land \\
(\neg a \lor \neg c ) \land \\
(\neg b \lor d \lor a )
$$

Each free symbol is a *variable* - here we have $a$, $b$, $c$, and $d$

Each combination of variable and sign is a *term* (usually called a *literal*), so $a$ and $\neg a$ are different terms.

Each set of $\lor$'s is a *clause*

All the clauses joined together by $\land$ form a *circuit*

# Python definitions

```python
from __future__ import annotations
from typing import List, Tuple, Dict
from itertools import product
```

```python
Variable = str

class Term:
    def __init__(self, name: str):
        self.name = name
        self.variable = name.split("¬")[-1]
        self.sign = not (name.startswith("¬"))

    def __repr__(self) -> str:
        return self.name

    def assign(self, b: bool) -> bool:
        return b if self.sign else (not b)
```

```python
# this is the full version we will use
class Term:
    def __init__(self, name: str):
        self.name = name
        # strip the "¬" prefix to get the underlying variable
        self.variable = name.split("¬")[-1]
        # True for a positive term (x), False for a negated one (¬x)
        self.sign = not (name.startswith("¬"))

    def __repr__(self) -> str:
        return self.name

    def __eq__(self, b: Term) -> bool:
        return b.name == self.name

    # multiplying by a negative number flips the sign: -1 * x == ¬x
    def __mul__(self, i: int) -> Term:
        return self if i > 0 else Term("¬" * self.sign + self.variable)

    def __rmul__(self, i: int) -> Term:
        return self.__mul__(i)

    # -x is shorthand for the negated term
    def __neg__(self) -> Term:
        return self.__mul__(-1)

    def __hash__(self) -> int:
        return hash(self.name)

    def assign(self, b: bool) -> bool:
        # value of this term when its variable is set to b
        return b if self.sign else (not b)
```

```python
print(Term("x"))
print(Term("¬x"))
```

```
x
¬x
```


```python
Term("¬x").assign(True)
```

```
False
```

```python
from pprint import pprint as pretty_print

def pprint(x):
    # narrow width so each clause prints on its own line
    pretty_print(x, width=20)
```

# Python definitions

```python
Clause = Tuple[Term, ...]
Circuit = List[Clause]

# Now we can define circuits, for example:
circuit = [(Term("a"), Term("b"), Term("¬c")),
           (Term("¬a"), Term("¬c")),
           (Term("¬b"), Term("d"), Term("a"))]
pprint(circuit)
```

```
[(a, b, ¬c),
 (¬a, ¬c),
 (¬b, d, a)]
```


$$
(a \lor b \lor \neg c) \land \\
(\neg a \lor \neg c ) \land \\
(\neg b \lor d \lor a )
$$

# Assignment

```python
# We also introduce an assignment, a mapping of var -> setting
Assignment = Dict[Variable, bool]

# check if `any` term in a clause is True
def eval_clause(clse: Clause, assigns: Assignment) -> bool:
    return any(t.assign(assigns[t.variable]) for t in clse)

# check if `all` clauses evaluate to True
def eval_circuit(circ: Circuit, assigns: Assignment) -> bool:
    return all(eval_clause(clse, assigns) for clse in circ)
```

```python
assignment = {"a": True, "b": True, "c": True, "d": True}

# Does this assignment satisfy our circuit?
eval_circuit(circuit, assignment)
```

```
False
```

# Exhaustive solution

```python
Solution = Tuple[bool, Assignment]

def exhaustive_search(circ: Circuit) -> Solution:
    # sort the variables so the search order is reproducible
    variables = sorted({term.variable for clse in circ for term in clse})

    # try every one of the 2**n combinations of True/False
    for bools in product([True, False], repeat=len(variables)):
        assigns = dict(zip(variables, bools))
        if eval_circuit(circ, assigns):
            return True, assigns

    return False, {}

exhaustive_search(circuit)
```

```
(True, {'a': True, 'b': True, 'c': False, 'd': True})
```

$$
(a \lor b \lor \neg c) \land \\
(\neg a \lor \neg c ) \land \\
(\neg b \lor d \lor a )
$$

# Exhaustive solution

Exhaustive search is bad news: with $n$ variables there are $2^n$ configurations to try.

$2^4$ is fine, but things get out of hand quickly: $2^{300}$ is more than the number of atoms in the universe.
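A quick check of that claim:

$$
2^{300} = 10^{300 \log_{10} 2} \approx 10^{90.3} \approx 2 \times 10^{90}
$$

compared with the usual estimate of around $10^{80}$ atoms in the observable universe.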

So, can we do better than trying every assignment? In the worst case, not as far as anyone knows; in practice, yes!

<center>
<H1> Doing better </H1>
</center>

# DPLL

The Davis–Putnam–Logemann–Loveland (DPLL) algorithm, invented in 1962

Still the foundation of modern SAT solvers, which extend it with techniques such as conflict-driven clause learning (CDCL)

Could take $2^n$ steps, but in practice this algorithm is pretty fast, and scales to large problems

# DPLL

The core algorithm is pretty simple, starting with an input circuit $C$

- Make an assumption about one of the variables in $C$, e.g. set $a = {\rm True}$

- Use that assumption to simplify the circuit by removing terms, $C \rightarrow C^\prime$

- Recursively call DPLL on the new smaller, simpler circuit $C^\prime$

# DPLL


1. Make an assumption about one of the variables in $C$, e.g. set $a = {\rm True}$
2. Use that assumption to simplify the circuit by removing terms, $C \rightarrow C^\prime$
3. Recursively call DPLL on the new smaller, simpler $C^\prime$

There are 2 possible outcomes:

- If we delete all the terms in a clause, we have a *contradiction* somewhere, so we have to change an assumption
- If we end up with an empty circuit, we've solved the problem: the assumptions are correct

# How do we simplify a circuit?
<br>
$$
C = 
(a \lor b \lor \neg c) \land \\
(\neg a \lor \neg c ) \land \\
(\neg b \lor d \lor a )
$$

Assume that $a = {\rm True}$.

Any clause containing the term $a$ is now ${\rm True}$ by definition, so delete it.

We can ignore $\neg a$ in any clause, since it can't contribute to making the clause ${\rm True}$

$$
C^\prime = (\neg c)
$$

For this circuit, if we now set $c = {\rm False}$ we have a solution

# How do we simplify a circuit?

```python
# remove any clause that contains the Term
def rm_clse(circ: Circuit, trm: Term) -> Circuit:
    return [clse for clse in circ if trm not in clse]

# Remove the Term from every clause in the circuit
def rm_term(circ: Circuit, trm: Term) -> Circuit:
    return [tuple(t for t in clse if t != trm) for clse in circ]

pprint(circuit)
rm_term(rm_clse(circuit, Term("a")), Term("¬a"))
```

```
[(a, b, ¬c),
 (¬a, ¬c),
 (¬b, d, a)]

[(¬c,)]
```

$$
(a \lor b \lor \neg c) \land \\
(\neg a \lor \neg c ) \land \\
(\neg b \lor d \lor a )
$$

```python
def simple_dpll(circ: Circuit, assigns: Assignment = {}) -> Solution:
    # empty circuit is True
    if len(circ) == 0:
        return True, assigns
    # empty clause is False
    if any(len(clse) == 0 for clse in circ):
        return False, {}

    # pick a variable
    v = Term(circ[0][0].variable)
    # set it to True, simplify, recurse
    new_circ = rm_term(rm_clse(circ, v), -v)
    sat, pot_assign = simple_dpll(new_circ, {**assigns, **{v.variable: True}})
    if sat:
        return sat, pot_assign
    # or set it to False, simplify, recurse
    new_circ = rm_term(rm_clse(circ, -v), v)
    sat, pot_assign = simple_dpll(new_circ, {**assigns, **{v.variable: False}})
    if sat:
        return sat, pot_assign

    # neither choice worked: backtrack
    return False, {}

simple_dpll(circuit)
```

```
(True, {'a': True, 'c': False})
```

# DPLL

That's the core DPLL algorithm, but there are 2 other important kinds of simplifications

1. Unit clause propagation
2. Pure term resolution

These simplification methods are pretty straightforward, and implementing them makes the algorithm much more powerful

# Unit Clause Propagation

If a clause contains only one term (like clause 4 below), it's called a *unit clause*

$$
(a \lor d \lor \neg c) \land \\
(\neg a \lor \neg c ) \land \\
(\neg b \lor d \lor a ) \land \\
(b) \land \\
(a \lor \neg d)
$$

That clause has to be ${\rm True}$, or the whole circuit fails.

So set it to ${\rm True}$ (here $b = {\rm True}$), and simplify.

This is super-useful: we often get a cascade of unit clauses, which makes the circuit much simpler
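A minimal, self-contained sketch of unit propagation on its own, separate from the `Term`-based code in this talk: here terms are plain strings with a leading `¬` for negation, and `propagate_units` is a hypothetical name. The chain $(a) \land (\neg a \lor b) \land (\neg b \lor c) \land (\neg c \lor d \lor e)$ shows the cascade: setting $a$ creates the unit clause $(b)$, which then creates $(c)$.

```python
from typing import Dict, List, Optional, Tuple

# A clause is a tuple of terms; a leading "¬" marks a negated term
Clause = Tuple[str, ...]

def negate(term: str) -> str:
    return term[1:] if term.startswith("¬") else "¬" + term

def propagate_units(clauses: List[Clause]) -> Tuple[Dict[str, bool], Optional[List[Clause]]]:
    """Repeatedly make unit clauses True and simplify.

    Returns (assignments, remaining clauses), or (assignments, None)
    if some clause becomes empty, i.e. a contradiction."""
    assigns: Dict[str, bool] = {}
    while True:
        units = [clse[0] for clse in clauses if len(clse) == 1]
        if not units:
            return assigns, clauses
        unit = units[0]
        # record the setting that makes this unit term True
        variable = unit[1:] if unit.startswith("¬") else unit
        assigns[variable] = not unit.startswith("¬")
        # drop clauses the unit satisfies; remove its negation from the rest
        clauses = [tuple(t for t in clse if t != negate(unit))
                   for clse in clauses if unit not in clse]
        if any(len(clse) == 0 for clse in clauses):
            return assigns, None

chain = [("a",), ("¬a", "b"), ("¬b", "c"), ("¬c", "d", "e")]
print(propagate_units(chain))
# ({'a': True, 'b': True, 'c': True}, [('d', 'e')])
```

Each assignment creates the next unit clause, so three variables are fixed without any guessing; only $(d \lor e)$ is left for the branching step.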

# Pure Term Resolution

If a variable only ever appears as a positive term, or only as a negative term, that term is called *pure*, like $\neg c$ below

$$
(a \lor d\lor \neg c) \land \\
(\neg a \lor \neg c ) \land \\
(\neg b \lor d \lor a ) \land \\
(b) \land \\
(a \lor \neg d)
$$

We can make that term ${\rm True}$ everywhere without affecting any other clause.

So set it to ${\rm True}$ (here $c = {\rm False}$), and simplify (in this case remove clauses 1 & 2)
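Following this example one step further: after setting $c = {\rm False}$ and removing clauses 1 and 2, we are left with

$$
(\neg b \lor d \lor a) \land \\
(b) \land \\
(a \lor \neg d)
$$

Now $a$ only appears as a positive term, so it is pure too. Setting $a = {\rm True}$ removes the first and last clauses, leaving the unit clause $(b)$; setting $b = {\rm True}$ then solves the circuit ($d$ can take either value).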

```python
# an Update is the new assignments plus the simplified circuit
Update = Tuple[Assignment, Circuit]

def pure_literal_elimination(circ: Circuit) -> Update:
    terms = [t for clse in circ for t in clse]
    # a term is pure if its negation never appears
    pures = {t for t in terms if -t not in terms}

    assigns = {p.variable: p.sign for p in pures}
    new_circ = circ.copy()
    for p in pures:
        new_circ = rm_clse(new_circ, p)

    return assigns, new_circ


def unit_clause_resolution(circ: Circuit, assigns: Assignment = {}) -> Update:
    # stop once there are no unit clauses left
    if all(len(clse) != 1 for clse in circ):
        return assigns, circ

    units = {clse[0] for clse in circ if len(clse) == 1}
    assigns = {**assigns, **{u.variable: u.sign for u in units}}

    new_circ = circ.copy()
    for u in units:
        new_circ = rm_term(rm_clse(new_circ, u), -u)

    # simplifying may create new unit clauses, so repeat
    return unit_clause_resolution(new_circ, assigns)


def dpll(circ: Circuit, assigns: Assignment = {}) -> Solution:
    if len(circ) == 0:
        return True, assigns

    if any(len(clse) == 0 for clse in circ):
        return False, {}

    # do resolution
    unit_a, unit_circ = unit_clause_resolution(circ)
    pure_a, resolved_circ = pure_literal_elimination(unit_circ)

    assigns = {**assigns, **pure_a, **unit_a}
    # did resolution solve the problem?
    if len(resolved_circ) == 0:
        return True, assigns
    # did resolution expose a contradiction?
    if any(len(clse) == 0 for clse in resolved_circ):
        return False, {}

    # branch on a variable that is still in the simplified circuit
    v = Term(resolved_circ[0][0].variable)

    new_circ = rm_term(rm_clse(resolved_circ, v), -v)
    sat, pot_assign = dpll(new_circ, {**assigns, **{v.variable: True}})
    if sat:
        return sat, pot_assign

    new_circ = rm_term(rm_clse(resolved_circ, -v), v)
    sat, pot_assign = dpll(new_circ, {**assigns, **{v.variable: False}})
    if sat:
        return sat, pot_assign

    return False, {}
```

<center>
<H1> Let's see some examples </H1>
</center>

# Let's start with the Graph Colouring problem

Can we colour a graph so that no two connected nodes have the same colour?

We can easily write this as a SAT problem. For example, can we colour nodes 1 and 2 using the colours Red and Blue?

![graph](img/2graph_uncoloured.png)

# Graph Colouring


<img src="img/2graph_uncoloured.png" width="45%" align="right"> 
 
<br> 
$$
({\rm Red}_1 \lor {\rm Blue}_1) \land \\
(\neg {\rm Red}_1 \lor \neg{\rm Blue}_1) \land \\
({\rm Red}_2 \lor {\rm Blue}_2) \land \\
(\neg {\rm Red}_2 \lor \neg {\rm Blue}_2) \land \\
(\neg {\rm Red}_1 \lor \neg {\rm Red}_2) \land \\
(\neg{\rm Blue}_1 \lor \neg{\rm Blue}_2)
$$

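```python
def read_file(filename: str) -> List[str]:
    with open(filename) as f:
        txt = f.read()
    return parse_txt(txt)

def parse_txt(txt: str) -> List[str]:
    # one clause per non-blank, non-comment line; duplicate lines are
    # dropped (dict.fromkeys de-duplicates while keeping the file order)
    lines = (s for s in txt.split("\n") if s != "" and not s.startswith("#"))
    return list(dict.fromkeys(lines))

def make_circ(lines: List[str]) -> Circuit:
    # each line is a space-separated clause, e.g. "¬Red1 ¬Blue1"
    return [tuple(Term(t) for t in clse.split()) for clse in lines]
```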

```python
colour2 = """
# Assign at least one colour to region 1
Red1 Blue1

# But no more than one colour
¬Red1 ¬Blue1

# Similarly for region 2
Red2 Blue2
¬Red2 ¬Blue2

# Make sure regions 1 and 2 are not coloured the same since they are neighbours
¬Red1 ¬Red2
¬Blue1 ¬Blue2
"""

dpll(make_circ(parse_txt(colour2)))
```

```
(True, {'Red1': True, 'Blue1': False, 'Red2': False, 'Blue2': True})
```

![graph](img/2graph.png)

# So far so good...

We can solve simple graph problems, how about something more complex?

A Sudoku is a graph colouring problem: can you colour this graph using 9 colours without any duplicates in a row, column, or 3×3 square?

<img src="img/sudoku.png"></img>

```python
from itertools import combinations

grid = '''\
AA AB AC BA BB BC CA CB CC
AD AE AF BD BE BF CD CE CF
AG AH AI BG BH BI CG CH CI
DA DB DC EA EB EC FA FB FC
DD DE DF ED EE EF FD FE FF
DG DH DI EG EH EI FG FH FI
GA GB GC HA HB HC IA IB IC
GD GE GF HD HE HF ID IE IF
GG GH GI HG HH HI IG IH II
'''

values = list('123456789')

table = [row.split() for row in grid.splitlines()]
points = grid.split()
subsquares = dict()
for point in points:
    subsquares.setdefault(point[0], []).append(point)
# Groups:  rows   + columns           + subsquares
groups = table[:] + list(zip(*table)) + list(subsquares.values())


def assignment_to_str(assigns: Assignment) -> str:
    # keep the True variables, e.g. "AA5" -> square "AA" holds "5"
    sq_vals = {k[:2]: k[-1] for k, v in assigns.items() if v}
    nums = "".join(sq_vals[g] for g in grid.split())
    return nums


def show_string(sudoku: str):
    'Display grid from a string (values in row major order)'
    n = 3
    fmt = '|'.join(['%s' * n] * n)
    sep = '+'.join(['-'  * n] * n)
    for i in range(n):
        for j in range(n):
            offset = (i * n + j) * n**2
            print(fmt % tuple(sudoku[offset:offset+n**2]))
        if i != n - 1:
            print(sep)

def exactly_one_of(elements: List[Term]) -> Circuit:
    # at most one: no pair of elements can both be True
    at_most_one = list(combinations([-e for e in elements], 2))
    # at least one: a single clause containing every element
    return at_most_one + [tuple(elements)]

def column_print(l, cols=4, width=13):
    'Print the clauses in `l`, `cols` to a row'
    for i in range(0, len(l), cols):
        row = l[i:i + cols]
        print("".join(" ".join(map(str, clse)).ljust(width) for clse in row))


sudoku = '53..7....6..195....98....6.8...6...34..8.3..17...2...6.6....28....419..5....8..79'
```

# Sudoku as SAT problem

We want to write a SAT circuit for the Sudoku constraints

This will be a **big** circuit compared with what we've seen

```python
# We want to add the constraint that square (A, A)
# contains exactly one number from 1-9
rules = exactly_one_of([Term(f"AA{i}") for i in range(1, 10)])
column_print(rules)
```

```
¬AA1 ¬AA2    ¬AA1 ¬AA3    ¬AA1 ¬AA4    ¬AA1 ¬AA5    
¬AA1 ¬AA6    ¬AA1 ¬AA7    ¬AA1 ¬AA8    ¬AA1 ¬AA9    
¬AA2 ¬AA3    ¬AA2 ¬AA4    ¬AA2 ¬AA5    ¬AA2 ¬AA6    
¬AA2 ¬AA7    ¬AA2 ¬AA8    ¬AA2 ¬AA9    ¬AA3 ¬AA4    
¬AA3 ¬AA5    ¬AA3 ¬AA6    ¬AA3 ¬AA7    ¬AA3 ¬AA8    
¬AA3 ¬AA9    ¬AA4 ¬AA5    ¬AA4 ¬AA6    ¬AA4 ¬AA7    
¬AA4 ¬AA8    ¬AA4 ¬AA9    ¬AA5 ¬AA6    ¬AA5 ¬AA7    
¬AA5 ¬AA8    ¬AA5 ¬AA9    ¬AA6 ¬AA7    ¬AA6 ¬AA8    
¬AA6 ¬AA9    ¬AA7 ¬AA8    ¬AA7 ¬AA9    ¬AA8 ¬AA9    
AA1 AA2 AA3 AA4 AA5 AA6 AA7 AA8 AA9
```


# Sudoku as SAT problem

Every Sudoku starts with the same constraints, so we need to encode that structure in SAT

```python
sudoku_circuit = make_circ(read_file("code/circuits/sudoku_circuit.txt"))

print(f"Number of clauses: {len(sudoku_circuit)}")
print(f"Number of variables: {len({t.variable for clse in sudoku_circuit for t in clse})}")
```

```
Number of clauses: 10560
Number of variables: 729
```


```python
# for those keeping score:
print(2**729)
print()
print(f"{2**729:.3E}")
```

2824013958708217496949108842204627863351353911851577524683401930862693830361198499905873920995229996970897865498283996578123296865878390947626553088486946106430796091482716120572632072492703527723757359478834530365734912

2.824E+219


# Sudoku as SAT problem

After those general constraints, we need the given numbers for that particular Sudoku

These are essentially a set of *unit clauses*: for instance, here we know that the term AA5 (a 5 in the top-left square) is True
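A sketch of how the givens could be turned into unit clauses, assuming the circuit-file format used above (one clause per line, variable names like `AA5` meaning "square AA holds a 5"). `givens_to_unit_clauses` is a hypothetical helper, and the square labels are copied from the `grid` string defined earlier so the block runs on its own:

```python
from typing import List

# Square labels in row-major order, copied from the `grid` string above
GRID = """\
AA AB AC BA BB BC CA CB CC
AD AE AF BD BE BF CD CE CF
AG AH AI BG BH BI CG CH CI
DA DB DC EA EB EC FA FB FC
DD DE DF ED EE EF FD FE FF
DG DH DI EG EH EI FG FH FI
GA GB GC HA HB HC IA IB IC
GD GE GF HD HE HF ID IE IF
GG GH GI HG HH HI IG IH II
"""

def givens_to_unit_clauses(puzzle: str) -> List[str]:
    """Hypothetical helper: one unit clause (a single positive term)
    for each filled-in square; '.' marks an empty square."""
    squares = GRID.split()
    return [square + digit for square, digit in zip(squares, puzzle) if digit != "."]

puzzle = '53..7....6..195....98....6.8...6...34..8.3..17...2...6.6....28....419..5....8..79'
units = givens_to_unit_clauses(puzzle)
print(len(units), units[:4])  # 30 ['AA5', 'AB3', 'BB7', 'AD6']
```

Each of these strings would be one line (one unit clause) in a circuit file, so unit propagation can fix them before any branching happens.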

```python
# Let's solve this sudoku
show_string(sudoku)
```

```
53.|.7.|...
6..|195|...
.98|...|.6.
---+---+---
8..|.6.|..3
4..|8.3|..1
7..|.2.|..6
---+---+---
.6.|...|28.
...|419|..5
...|.8.|.79
```


```python
sudoku_circuit = make_circ(read_file("code/circuits/sudoku_circuit.txt"))

solved, assignment = dpll(sudoku_circuit)
show_string(assignment_to_str(assignment))
```

```
534|678|912
672|195|348
198|342|567
---+---+---
859|761|423
426|853|791
713|924|856
---+---+---
961|537|284
287|419|635
345|286|179
```


# Better solver

If we want to use a production solver, we have a lot of options.

`pycosat` is a really powerful, really simple tool

It's about 3,000 times faster than my hand-built solver

```python
from pycosat import itersolve

def pycosat_solve(circ: Circuit) -> Solution:
    # pycosat expects variables numbered 1..n, with a negative number for ¬x
    variables = {t.variable for c in circ for t in c}
    to_sym = dict(zip(variables, range(1, len(variables) + 1)))
    from_sym = {v: k for k, v in to_sym.items()}

    sym = []
    for c in circ:
        sym.append([to_sym[t.variable] if t.sign else -to_sym[t.variable]
                    for t in c])

    # take the first solution, if there is one
    sol = next(itersolve(sym), False)
    if sol:
        return True, {from_sym[abs(s)]: s > 0 for s in sol}
    return False, {}
```

```python
solved, assignment = pycosat_solve(sudoku_circuit)
show_string(assignment_to_str(assignment))
```

```
534|678|912
672|195|348
198|342|567
---+---+---
859|761|423
426|853|791
713|924|856
---+---+---
961|537|284
287|419|635
345|286|179
```


# Last word

You've seen the core algorithm that powers SAT solvers

These solvers have lots of uses and are pretty ubiquitous, for example `pycosat` is the dependency resolver in Conda

Even though SAT is NP-complete, in practice we can scale to big problems

There are lots of practical applications for these methods (planning, optimisation, etc.)

![contact](img/contact-card.png)