In [1]:
import sys
import os
import time
import signal
import itertools
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
import pyeda.inter as inter
import pyeda.boolalg as boolalg
from tqdm.notebook import trange
from graphviz import Digraph, Source
from IPython.display import SVG, HTML, display
from pprint import pprint

sys.path.append(os.path.join(os.path.abspath(""), ".."))

from compiling_nn.build_odd import ODD, compile_nn
from utils.custom_activations import StepActivation

infty = float('inf')

## Tests

In [2]:
def layers2nn(layers):
    a = []    
    for x, y in zip(layers[:-1], layers[1:]):
        a.append(nn.Linear(x, y))
        a.append(StepActivation())
    return nn.Sequential(*a)

# TODO test limit case
TEST = True

def test_odd_build(n_weights):
    neuron = nn.Linear(n_weights,1)

    odd_test = ODD()
    odd_test.build_odd_rec(neuron.weight[0], -neuron.bias[0])
    for p in itertools.product([0, 1],repeat=n_weights):
        s = 0
        for w, e in zip(neuron.weight[0], p):
            s+=w*e
        eval_sum = s + neuron.bias[0] >= 0
        eval_odd = bool(odd_test.root.eda_expr.restrict({k: v for k,v in zip(odd_test.eda_vars,p)}))
        assert(eval_odd == eval_sum)
    del odd_test

def test_combine_odd(layers):
    net = layers2nn(layers)
    bdd, bdd_vars = compile_nn(net)
    # try:
    for p in itertools.product([0, 1],repeat=len(bdd_vars)):
        eval_bdd = bdd.restrict({k: v for k,v in zip(bdd_vars,p)})
        eval_bdd = bool(eval_bdd)
        eval_net = bool(net.forward(torch.Tensor(p)))
        assert(eval_bdd == eval_net)
    # except AssertionError:
    #     print(bdd, bdd_vars)
    #     print(p, "| bdd", eval_bdd, "| net", eval_net)
    #     params = list(net.parameters())
    #     odds = layers2odds(zip(params[::2], params[1::2]))
    #     print(params)
    #     display(Source(bdd.to_dot()))
    #     for x in odds:
    #         print("_"*30)
    #         for odd in x:
    #             display(Source(odd.bdd.to_dot()))
    #     return True

if TEST:
    # for n_weights in trange(1, 11): 
    #     for _ in range(100):
    #         test_odd_build(n_weights)
    
    for _ in trange(1000):
        if test_combine_odd([5,5,5,5,1]):
            break

  0%|          | 0/1000 [00:00<?, ?it/s]

AttributeError: 'NoneType' object has no attribute 'root'

## Main

In [None]:
layers_nn = [15,1]

net = layers2nn(layers_nn)
bdd = compile_nn(net, verbose=True)
if layers_nn[0] < 10:
    display(Source(bdd.to_dot()))

converting to ODDs : DONE (1.31e-01)
combining ODDs     : DONE (8.11e-06)


TypeError: cannot unpack non-iterable BinaryDecisionDiagram object

## Approx BDD

### General case

In [None]:
def l(cmax, node):
    return node.root if node.root > 0 else cmax+1

memo_zp = dict()
def zp(node, label):
    a = memo_zp.get(node)
    if a:
        return a.get(label)
    return None

# pre compute zp by bottom-up algorithm
def compute_zp(bdd, c):
    for x in bdd.dfs_postorder():
        if x.root < 0: # if leaf (0/1)
            memo_zp[x] = {c: x}
            continue
        memo_zp[x] = {**memo_zp[x.lo], l(c, x): x}

def init_find_incl(bdd, c):
    R = []
    nodes = {k:[] for k in range(1,c+2)}
    for x in bdd.dfs_preorder():
        R.append((x,x))
        if x.root > 0:
            nodes[x.root].append(x)
        else:
            nodes[c+1].append(x)
    R.append((nodes[c+1][0], nodes[c+1][1]) if nodes[c+1][0] == -2 else (nodes[c+1][1], nodes[c+1][0]))

    return nodes, R

def find_incl(bdd, c):
    nodes, R = init_find_incl(bdd, c)
    print(len(R))
    for k in range(c, 0, -1):
        for u,v in itertools.combinations(nodes[k], 2):
            if (u.lo, zp(v.lo, l(c, u.lo))) in R and (u.hi, zp(v.hi, l(c, u.hi))) in R:
                R.append((u,v))
            if (v.lo, zp(u.lo, l(c, v.lo))) in R and (v.hi, zp(u.hi, l(c, v.hi))) in R:
                R.append((v, u))

    print(len(R)) 

compute_zp(bdd, len(bdd_vars))
find_incl(bdd,len(bdd_vars))

428
432


### Monotone

In [None]:
layers_nn_mono = [15, 1]

while True:
    net_mono = layers2nn(layers_nn_mono)
    for i, params in enumerate(net_mono.parameters()):
        if i%2==0:
            params.data = abs(params.data)
        else:
            params.data = -10 * abs(params.data)

    bdd_mono, bdd_vars_mono = compile_nn(net_mono)
    if not bdd_mono.is_one() and not bdd_mono.is_zero():
        break

memo_F = dict()
def compute_F(bdd): # ZDD algo -> adapt to BDD (?)
    for x in bdd.dfs_postorder():
        if x.root > 0:
            memo_F[x] = {frozenset(S | {x.root}) for S in memo_F[x.hi]} | memo_F[x.lo]
        elif x.root == -2: # case True
            memo_F[x] = {frozenset()}
        elif x.root == -1: # case False
            memo_F[x] = set()

def F(v):
    return memo_F[v]

def argmin(d):
    return min(d, key=d.get)

def find_incl_mono(bdd, c):
    special_nodes = set(memo_zp[bdd.node].values()) # nodes reached by 0-edge from the root
    nodes, _ = init_find_incl(bdd, c)
    nodes[0] = dict() # simplify code for P0 and P1 computation
    
    sup = dict() # init ??

    for k in range(1, c+1):
        for u in set(nodes[k]) - special_nodes:
            candidates = set()
            P0 = {v for v in nodes[k-1] if v.lo is u}
            P1 = {v for v in nodes[k-1] if v.hi is u}

            for p in P0 | P1:
                q = p
                zp_q = zp(q, l(c, u))
                while u is zp_q:
                    q = sup[q]
                    zp_q = zp(q, l(c, u))
                
                candidates.add(zp_q) # zp can return None
                if zp_q is None:
                    sys.exit("zp_q is None")

            for p in P1:
                q = p
                zp_q1 = zp(q.hi, l(c, u))
                # print(zp_q1, q.hi.root, l(c,u), sup)
                while u is zp_q1  and q not in special_nodes:
                    q = sup[q]
                    zp_q1 = zp(q.hi, l(c, u))
                if u is not zp_q1:
                    candidates.add(zp_q1) # zp_q1 can be None
            
            if candidates: 
                sup[u] = argmin({v: len(F(v)) for v in candidates}) # KeyError if None in candidates
            
    print(sup)

memo_zp.clear()
compute_F(bdd_mono)
compute_zp(bdd_mono, len(bdd_vars_mono))
find_incl_mono(bdd_mono, len(bdd_vars_mono))

KeyError: <pyeda.boolalg.bdd.BDDNode object at 0x7f40318cfc10>