In [1]:
import sys
import time
import signal
import itertools
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
import pyeda.inter as inter
import pyeda.boolalg as boolalg
from tqdm.notebook import trange
from graphviz import Digraph, Source
from IPython.display import SVG, HTML, display
from pprint import pprint

infty = float('inf')

## Class
### ODD

In [2]:
class ODD():
    def __init__(self):
        self.root = None
        self.cache = dict()
        self.bdd = None
        self.svg_graph = None
        self.eda_vars = dict() # only for restrict bdd

    def __str__(self):
        return str(boolalg.bdd.bdd2expr(self.root.eda_expr))

    def display_expr(self):
        if self.bdd is not None:
            dot_odd = self.bdd.to_dot()
            display(Source(dot_odd))

    def display_graph(self):
        if not self.svg_graph:
            self.make_graph()
        display(SVG(self.svg_graph))
                

class Node:
    def __init__(self, interval=None):
        self.child = dict()
        self.parent = dict()
        self.interval = interval
        self.id = f"n{id(self)}"
        self.eda_expr = None

    def __str__(self):
        return f"{self.child.__len__()} children | {self.parent.__len__()} parents | {self.interval:.2f}"
    
    def add_child(self, other, label):
        self.child.update({label: other})
        other.parent.update({label: self})

    def pop_child(self, label):
        child = self.child.pop(label)
        child.parent.pop(label)
        return child

    def has_child(self):
        return len(self.child) > 0

class Interval:
    def __init__(self, low, high, closed_left=True, closed_right=True):
        self.low = low
        self.high = high
        self.left = closed_left
        self.right = closed_right
    
    def __str__(self):
        left = "[" if self.left else "("
        right = "]" if self.right else ")"
        return f"{left}{self.low}; {self.high}{right}"

    def __format__(self, __format_spec):
        left = "[" if self.left else "("
        right = "]" if self.right else ")"
        if __format_spec:
            return f"{left}{self.low:{__format_spec}}; {self.high:{__format_spec}}{right}"
        else:
            return f"{left}{self.low}; {self.high}{right}"

    def __add__(self, value):
        return Interval(self.low + value, self.high + value, self.left, self.right)

    def __sub__(self, value):
        return Interval(self.low - value, self.high - value, self.left, self.right)

    def __contains__(self, value):
        down = value >= self.low if self.left else value > self.low
        up = value <= self.high if self.right else value < self.high
        return down & up
    
    def intersect(self, other):
        if self.low < other.low:
            self.low = other.low
            self.left = other.left
        elif self.low == other.low:
            self.left &= other.left
        
        if self.high > other.high:
            self.high = other.high
            self.right = other.right
        elif self.high == other.high:
            self.right &= other.right

        return self

### Neural networks

In [3]:
class StepFunction(Function):
    @staticmethod
    def forward(ctx, input):
        output = torch.where(input>=0, torch.tensor(1.0), torch.tensor(0.0))
        ctx.save_for_backward(input)
        return output
    
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = torch.zeros_like(input)
        return grad_input
    
class StepActivation(nn.Module):
    def forward(self, input):
        a =  StepFunction.apply(input)
        return a
    
def layers2nn(layers):
    a = []    
    for x, y in zip(layers[:-1], layers[1:]):
        a.append(nn.Linear(x, y))
        a.append(StepActivation())
    return nn.Sequential(*a)

## Utilities

### Decorators

In [4]:
def handler(sig, frame):
    raise Exception("Function takes too much time")

def timelimit(maxtime=100):
    def inner(func):
        def wrapper(*args, **kwargs):
            signal.signal(signal.SIGALRM, handler)
            signal.alarm(maxtime)
            try:
                res = func(*args, **kwargs)
            except Exception as exc:
                print(exc)
            else:
                signal.alarm(-1)
                return res
        return wrapper
    return inner

def timecounter(message=None):
    def inner(func):
        def wrapper(*args, **kwargs):
            start = time.time()
            res = func(*args, **kwargs)
            deltatime = (time.time() - start)*1000
            if message:
                try:
                    print(message.format(deltatime))
                except Exception:
                    print(f"Done in {deltatime}ms")
            elif message == False:
                return res, deltatime
            else:
                print(f"Done in {deltatime}ms")
            return res
        return wrapper
    return inner

### Cache related functions

In [5]:
class ODD(ODD):
    def clear_cache(self):
        self.cache.clear()

    def store_in_cache(self, k, node):
        if self.cache.get(k):
            self.cache[k].append(node)
        else:
            self.cache.update({k: [node]})

    def find_in_cache(self, k, value):
        cache_line = self.cache.get(k)
        if cache_line:
            for node in cache_line:
                if value in node.interval:
                    return node
        return None

    # TODO save count ? or count as we build it ?
    def cache_total_node_count(self):
        node_count = 0
        for x in self.cache.values():
            node_count+=x.__len__()
        return node_count

    def rebuild_cache(self, odd):
        self.clear_cache()
        self.cache[0] = [odd]
        has_child = odd.has_child()
        depth = 0
        next_layer = set()
        while has_child:
            layer = self.cache[depth]
            for node in layer:
                for child in node.child.values():
                    next_layer.add(child)

            depth+=1
            self.cache[depth] = list(next_layer)
            next_layer.clear()
            has_child = self.cache[depth][0].has_child()
        
        return self.cache

### Draw graph

In [6]:
class ODD(ODD):
    def make_graph(self):
        if self.svg_graph is None:
            self._make_graph()
        return self.svg_graph
    
    def _make_graph(self):
        dot = Digraph()
        for cache_line in self.cache.values():
            for node in cache_line:
                dot.node(node.id, f"{node.interval:.2f}")
                for e, child in node.child.items():
                    dot.edge(node.id, child.id, str(e))

        self.svg_graph = dot._repr_image_svg_xml()

def str_bdd_tree(bdd):
    def _build_str(node):
        if node:
            return f"['root': {node.root}, 'hi': {_build_str(node.hi)}, 'lo': {_build_str(node.lo)}]"
    print(_build_str(bdd.node))

## Build ODD

In [7]:
class ODD(ODD):
    @timecounter(message=False)
    def timed_build_odd_rec(self, weights, threshold):
        return self._build_odd_rec(weights, threshold)

    def build_odd_rec(self, weights, threshold, label=""):
        self.weights = weights
        self.label = label
        self.root = self._build_odd_rec(weights, threshold)
        self.bdd = self.root.eda_expr

    @timelimit(600)
    def _build_odd_rec(self, weights, threshold):
        n = len(weights)
        zero_sink = Node(Interval(-infty, threshold, closed_left=False, closed_right=False))
        zero_sink.eda_expr = 0
        self.zero = zero_sink
        self.store_in_cache(n, zero_sink)

        one_sink = Node(Interval(threshold, infty, closed_right=False))
        one_sink.eda_expr = 1
        self.store_in_cache(n, one_sink)
        self.one = one_sink

        return self.build_sub_odd_rec(0, 0)

    def build_sub_odd_rec(self,  k, v):
        node = Node(Interval(-infty, infty, closed_left=False, closed_right=False))
        eda_var = inter.bddvar(f"{self.label}{k}")
        self.eda_vars[eda_var] = self.svg_graph
        weight = self.weights[k]
        for e in {0, 1}:
            w = e*weight
            v_child = v + w
            child = self.find_in_cache(k+1, v_child)
            if child is None:
                child = self.build_sub_odd_rec(k+1, v_child)
            node.add_child(child, e)
            node.interval.intersect(child.interval-w)
        node.eda_expr = inter.ite(eda_var, node.child[1].eda_expr, node.child[0].eda_expr)
        self.store_in_cache(k, node)
        return node

In [8]:
def layers2odds(layers):
    odds_layers = []
    for i, layer in enumerate(layers):
        odds = []
        label = f"l{i}_i" if i else "i"
        for weights, bias in zip(*layer):
            odd = ODD()
            odd.build_odd_rec(weights, -bias, label)
            odds.append(odd)
            odd.clear_cache()
        odds_layers.append(odds)
    return odds_layers

def combine_odds(odds):
    prev_layer = [(odd.bdd, odd.eda_vars) for odd in odds[0]]
    p_vars = prev_layer[0][1]
    for odds_next in odds[1:]:
        next_layer = [(odd.bdd, odd.eda_vars) for odd in odds_next]
        
        switch_layer = []
        for n_bdd, n_vars in next_layer:
            res_bdd = n_bdd.compose({n_var: p_bdd for (p_bdd, _), n_var in zip(prev_layer, n_vars)}) 
            switch_layer.append((res_bdd, p_vars))

        prev_layer = switch_layer
    
    return prev_layer[0]

def compile_nn(net, verbose=False):
    params = list(net.parameters())
    if verbose:
        print("converting to ODDs : ", end="")
        start_convert = time.perf_counter()
    odds = layers2odds(zip(params[::2], params[1::2]))
    if verbose:
        print(f"DONE ({time.perf_counter()-start_convert:1.2e})\ncombining ODDs : ", end="")
        start_combine = time.perf_counter()
    res = combine_odds(odds)
    if verbose:
        print(f"DONE ({time.perf_counter()-start_combine:1.2e})")
    return res


## Tests

In [9]:
# TODO test limit case
TEST = False

def test_odd_build(n_weights):
    neuron = nn.Linear(n_weights,1)

    odd_test = ODD()
    odd_test.build_odd_rec(neuron.weight[0], -neuron.bias[0])
    for p in itertools.product([0, 1],repeat=n_weights):
        s = 0
        for w, e in zip(neuron.weight[0], p):
            s+=w*e
        eval_sum = s + neuron.bias[0] >= 0
        eval_odd = bool(odd_test.root.eda_expr.restrict({k: v for k,v in zip(odd_test.eda_vars,p)}))
        assert(eval_odd == eval_sum)
    del odd_test

def test_combine_odd(layers):
    net = layers2nn(layers)
    bdd, bdd_vars = compile_nn(net)
    # try:
    for p in itertools.product([0, 1],repeat=len(bdd_vars)):
        eval_bdd = bdd.restrict({k: v for k,v in zip(bdd_vars,p)})
        eval_bdd = bool(eval_bdd)
        eval_net = bool(net.forward(torch.Tensor(p)))
        assert(eval_bdd == eval_net)
    # except AssertionError:
    #     print(bdd, bdd_vars)
    #     print(p, "| bdd", eval_bdd, "| net", eval_net)
    #     params = list(net.parameters())
    #     odds = layers2odds(zip(params[::2], params[1::2]))
    #     print(params)
    #     display(Source(bdd.to_dot()))
    #     for x in odds:
    #         print("_"*30)
    #         for odd in x:
    #             display(Source(odd.bdd.to_dot()))
    #     return True

if TEST:
    for n_weights in trange(1, 11): 
        for _ in range(100):
            test_odd_build(n_weights)
    
    for _ in trange(1000):
        if test_combine_odd([5,5,5,5,1]):
            break

## Main

In [10]:
layers_nn = [15,1]

net = layers2nn(layers_nn)
bdd, bdd_vars = compile_nn(net, verbose=True)
if layers_nn[0] < 10:
    display(Source(bdd.to_dot()))

converting to ODDs : DONE (2.02e-01)
combining ODDs : DONE (2.66e-06)


## Approx BDD

### General case

In [11]:
def l(cmax, node):
    return node.root if node.root > 0 else cmax+1

memo_zp = dict()
def zp(node, label):
    a = memo_zp.get(node)
    if a:
        return a.get(label)
    return None

# pre compute zp by bottom-up algorithm
def compute_zp(bdd, c):
    for x in bdd.dfs_postorder():
        if x.root < 0: # if leaf (0/1)
            memo_zp[x] = {c: x}
            continue
        memo_zp[x] = {**memo_zp[x.lo], l(c, x): x}

def init_find_incl(bdd, c):
    R = []
    nodes = {k:[] for k in range(1,c+2)}
    for x in bdd.dfs_preorder():
        R.append((x,x))
        if x.root > 0:
            nodes[x.root].append(x)
        else:
            nodes[c+1].append(x)
    R.append((nodes[c+1][0], nodes[c+1][1]) if nodes[c+1][0] == -2 else (nodes[c+1][1], nodes[c+1][0]))

    return nodes, R

def find_incl(bdd, c):
    nodes, R = init_find_incl(bdd, c)
    print(len(R))
    for k in range(c, 0, -1):
        for u,v in itertools.combinations(nodes[k], 2):
            if (u.lo, zp(v.lo, l(c, u.lo))) in R and (u.hi, zp(v.hi, l(c, u.hi))) in R:
                R.append((u,v))
            if (v.lo, zp(u.lo, l(c, v.lo))) in R and (v.hi, zp(u.hi, l(c, v.hi))) in R:
                R.append((v, u))

    print(len(R)) 

compute_zp(bdd, len(bdd_vars))
find_incl(bdd,len(bdd_vars))

428
432


### Monotone

In [58]:
layers_nn_mono = [15, 1]

while True:
    net_mono = layers2nn(layers_nn_mono)
    for i, params in enumerate(net_mono.parameters()):
        if i%2==0:
            params.data = abs(params.data)
        else:
            params.data = -10 * abs(params.data)

    bdd_mono, bdd_vars_mono = compile_nn(net_mono)
    if not bdd_mono.is_one() and not bdd_mono.is_zero():
        break

memo_F = dict()
def compute_F(bdd): # ZDD algo -> adapt to BDD (?)
    for x in bdd.dfs_postorder():
        if x.root > 0:
            memo_F[x] = {frozenset(S | {x.root}) for S in memo_F[x.hi]} | memo_F[x.lo]
        elif x.root == -2: # case True
            memo_F[x] = {frozenset()}
        elif x.root == -1: # case False
            memo_F[x] = set()

def F(v):
    return memo_F[v]

def argmin(d):
    return min(d, key=d.get)

def find_incl_mono(bdd, c):
    special_nodes = set(memo_zp[bdd.node].values()) # nodes reached by 0-edge from the root
    nodes, _ = init_find_incl(bdd, c)
    nodes[0] = dict() # simplify code for P0 and P1 computation
    
    sup = dict() # init ??

    for k in range(1, c+1):
        for u in set(nodes[k]) - special_nodes:
            candidates = set()
            P0 = {v for v in nodes[k-1] if v.lo is u}
            P1 = {v for v in nodes[k-1] if v.hi is u}

            for p in P0 | P1:
                q = p
                zp_q = zp(q, l(c, u))
                while u is zp_q:
                    q = sup[q]
                    zp_q = zp(q, l(c, u))
                
                candidates.add(zp_q) # zp can return None
                if zp_q is None:
                    sys.exit("zp_q is None")

            for p in P1:
                q = p
                zp_q1 = zp(q.hi, l(c, u))
                # print(zp_q1, q.hi.root, l(c,u), sup)
                while u is zp_q1  and q not in special_nodes:
                    q = sup[q]
                    zp_q1 = zp(q.hi, l(c, u))
                if u is not zp_q1:
                    candidates.add(zp_q1) # zp_q1 can be None
            
            if candidates: 
                sup[u] = argmin({v: len(F(v)) for v in candidates}) # KeyError if None in candidates
            
    print(sup)

memo_zp.clear()
compute_F(bdd_mono)
compute_zp(bdd_mono, len(bdd_vars_mono))
find_incl_mono(bdd_mono, len(bdd_vars_mono))

KeyError: <pyeda.boolalg.bdd.BDDNode object at 0x7f40318cfc10>