In [1]:
import networkx as nx
from manim import *
import math

config.media_width = "100%"
config.verbosity = "WARNING"

![ciao](media/tree.jpg)


In [2]:
def xor_op(a, b):
    return '1' if (int(a) ^ int(b)) else '0'

def and_op(a, b):
    return '1' if (int(a) & int(b)) else '0'

def or_op(a, b):
    return '1' if (int(a) | int(b)) else '0'



In [None]:

def generic_run(selfi, network, **kwargs):
    null_label = kwargs['null_label'] if 'null_label' in kwargs else '-'
    delay = kwargs['delay'] if 'delay' in kwargs else 1
    (Network, layout, label_edges) = network.get()

    unet = Network.to_undirected()
    network.init_computation()
    finished = False
    time = 0
    while not finished:
    #for time in range(0, 2*nbits):
        print(f'Iteration {time}')
        G = Graph.from_networkx(Network, layout=layout, labels={i : Network.nodes[i]['weight']['value'] for i in Network})
        selfi.add(G)
        if time == 0:
            counter = Text(str(time)).to_edge(UL)
            selfi.add(counter)
            texts = [Text(str(Network[u][v]['weight']['carry']), font_size=32, color=BLACK if Network[u][v]['weight']['carry']==null_label else RED).next_to(label_edges[(u, v)])
                        for u, v in Network.edges]
            for i in texts:
                selfi.add(i)
        else:
            new_texts = [Text(str(Network[u][v]['weight']['carry']), font_size=32, color=BLACK if Network[u][v]['weight']['carry']==null_label else RED).next_to(label_edges[(u, v)])
                        for u, v in Network.edges]
            selfi.play(Transform(counter, Text(str(time)).to_edge(UL)),*[Transform(t1, t2) for t1, t2 in zip(texts, new_texts)])
        selfi.wait(delay)
        finished = not network.compute_iteration()
        time = time+1
    selfi.wait(delay)



In [4]:
class linear_sum:
    def is_carry(self, v):
        return (v-1)%3==0
    
    def is_input(self, v):
        return not self.is_carry(v)
        

    def __init__(self, nbits, a, b, **kwargs):
        # nbits is the number of bits of the operation
        self.nbits = nbits

        f = "{0:0"+str(self.nbits)+"b}"
        self.a = f.format(a)
        self.b = f.format(b)

        self.null_label = '-'
        self.Array = nx.DiGraph()
        
        #        23 22 21 20 19 18 27 16 15 14 13 12 11 10 9  8  7  6  5  4  3  2  1  0
        #        i  c  i  i  c  i  i  c  i  i  c  i  i  c  i  i  c  i  i  c  i  i  c  i
        # pos_x -12-11-10-9  -8 -7 -6 -5 -4 -3 -2 -1 0  1  2  3  4  5  6  7  8  9 10 11
        # pos_y  1  0  1  1  0  1  1  0  1  1  0  1  1  0  1  1  0  1  1  0  1  1  0  1
        for x in range(0, 3*self.nbits): # Input nodes
            #print(x)
            self.Array.add_node(x, weight = {'value': self.null_label})

        # Edges to compute the sum of input bits
        for x in range(1, 3*self.nbits, 3):
            #print(f'adding input edge ({x-1}, {x})')
            self.Array.add_edge(x-1, x, weight={"carry": self.null_label})
            #print(f'adding input edge ({x+1}, {x})')
            self.Array.add_edge(x+1, x, weight={"carry": self.null_label})

        # Edges to propagate the carry
        for x in range(1, 3*(self.nbits-1), 3):
            #print(f'adding carry edge ({x}, {x+3})')
            self.Array.add_edge(x, x+3, weight={"carry": self.null_label})

        self.layout = { x: [3*nbits/2-x, 0 if (x-1)%3==0 else -1, 0] for x in range(0, 3*self.nbits)}
        print(self.layout)

        # Compute the positions of the labels for edges
        #        23 22 21 20 19 18 27 16 15 14 13 12 11 10 9  8  7  6  5  4  3  2  1  0
        #        i  c  i  i  c  i  i  c  i  i  c  i  i  c  i  i  c  i  i  c  i  i  c  i
        # Solution 1 for input edges: put then at the mean between the coordinates of the endpoints
        # Solution 1 for carry edges: put label on top and between the x coordinates of the endpoints
        edge_labels_x = {u: (self.layout[u][0]+self.layout[v][0])/2 for u, v in self.Array.edges}
        edge_labels_y = {}
        for u,v in self.Array.edges:
            if self.is_input(u):
                edge_labels_y[u] = (self.layout[u][1]+self.layout[v][1])/2
            else:
                edge_labels_y[u] = 0.3
        self.labels_position = {(u, v): edge_labels_x[u]*RIGHT+edge_labels_y[u]*UP for u, v in self.Array.edges}

    def get(self):
        return (self.Array, self.layout, self.labels_position)

    def init_computation(self):
        print(self.a, self.b)

        for i in range(0, 3*self.nbits, 3):
            self.Array.nodes[i]['weight']['value'] = self.a[self.nbits-1-int(i/3)]
            self.Array.nodes[i+2]['weight']['value'] = self.b[self.nbits-1-int(i/3)]

    def compute_iteration(self):
        some_updates = False

        for node in range(3*self.nbits-1, -1, -1):

            if self.is_input(node):
                for u, v, d in self.Array.out_edges(node, data=True):
                    assert(u==node)
                    if self.Array.nodes[node]['weight']['value'] != self.null_label:
                        some_updates = True
                        #print(f'before: node {node}, edge ({u}, {v}), weigth {d} = {self.Array.nodes[node]['weight']['value']}')
                        d['weight']['carry'] = self.Array.nodes[node]['weight']['value']
                        self.Array.nodes[node]['weight']['value'] = self.null_label
                        #print(f'after:  node {node}, edge ({u}, {v}), weigth {d} = {self.Array.nodes[node]['weight']['value']}')
            else:
                #print("main computaiton")
                in_edges = [d for u, v, d in self.Array.in_edges(node, data=True)]
                in_edges_full = [(d, u, v) for u, v, d in self.Array.in_edges(node, data=True)]
                # check is the comutation is ready to start
                all_set = True
                for e in in_edges_full:
                    #print(f'              edge ({e[1]}, {e[2]}) -> {e[0]['weight']['carry']}')
                    all_set = all_set and (e[0]['weight']['carry'] != self.null_label)
                #print (f'everything ok for node {node}? {all_set}')
                if all_set:
                    some_updates = True
                    if (node==1): # it has only two incoming edges
                        #print("I'm here")
                        assert(len(in_edges) == 2)
                        a = in_edges[0]['weight']['carry']
                        b = in_edges[1]['weight']['carry']
                        in_edges[0]['weight']['carry'] = self.null_label
                        in_edges[1]['weight']['carry'] = self.null_label
                        self.Array.nodes[node]['weight']['value'] = xor_op(a, b)
                        for u, v, d in self.Array.out_edges(node, data=True):
                            d['weight']['carry'] = and_op(a, b)
                    else:
                        #print("I'm not here")
                        assert(len(in_edges) == 3)
                        a = in_edges[0]['weight']['carry']
                        b = in_edges[1]['weight']['carry']
                        c = in_edges[2]['weight']['carry']
                        in_edges[0]['weight']['carry'] = self.null_label
                        in_edges[1]['weight']['carry'] = self.null_label
                        in_edges[2]['weight']['carry'] = self.null_label
                        self.Array.nodes[node]['weight']['value'] = xor_op(xor_op(a, b), c)
                        for u, v, d in self.Array.out_edges(node, data=True):
                            d['weight']['carry'] = or_op(and_op(a, b), or_op(and_op(a, c), and_op(b, c)))
        return some_updates
    
    def read_output(self):
        s = ""
        n = 0
        i = 0
        for node in range(1, 3*(self.nbits), 3):
            s = s + self.Array.nodes[node]['weight']['value']
            if self.Array.nodes[node]['weight']['value'] != self.null_label:
                n = n + 2**i*int(self.Array.nodes[node]['weight']['value'])
            i = i + 1
        return s, n

In [None]:
# Testing networkw
nbits = 8
for i in range(0, 2**nbits):
	for j in range(0, 2**nbits):
		if (i+j < 2**nbits):
			network = linear_sum(nbits, i, j)
			network.init_computation()
			while network.compute_iteration():
				None
			assert( i+j == network.read_output()[1])

In [8]:
%%manim -qm ShowLinear

config.frame_width = 25


class ShowLinear(Scene):
    
    def construct(self):
        nbits = 8
        network = linear_sum(nbits, 32+84, 77+32, delay=0.2)
        generic_run(self, network)
        print(network.read_output())

{0: [12.0, -1, 0], 1: [11.0, 0, 0], 2: [10.0, -1, 0], 3: [9.0, -1, 0], 4: [8.0, 0, 0], 5: [7.0, -1, 0], 6: [6.0, -1, 0], 7: [5.0, 0, 0], 8: [4.0, -1, 0], 9: [3.0, -1, 0], 10: [2.0, 0, 0], 11: [1.0, -1, 0], 12: [0.0, -1, 0], 13: [-1.0, 0, 0], 14: [-2.0, -1, 0], 15: [-3.0, -1, 0], 16: [-4.0, 0, 0], 17: [-5.0, -1, 0], 18: [-6.0, -1, 0], 19: [-7.0, 0, 0], 20: [-8.0, -1, 0], 21: [-9.0, -1, 0], 22: [-10.0, 0, 0], 23: [-11.0, -1, 0]}
01110100 01101101
Iteration 0
Iteration 1


                                                                                        

Iteration 2


                                                                                        

Iteration 3


                                                                                        

Iteration 4


                                                                                        

Iteration 5


                                                                                        

Iteration 6


                                                                                         

Iteration 7


                                                                                         

Iteration 8


                                                                                         

Iteration 9


                                                                                         

('10000111', 225)


In [574]:
class tree_sum:
    def __init__(self, nbits, a, b, **kwargs):
        #print(kwargs)
        self.null_label = kwargs['null_label'] if 'null_label' in kwargs else '-'
        self.verbose = kwargs['verbose'] if 'verbose' in kwargs else False

        # nbits is the number of bits of the operation
        # N is the number of nodes in the tree (2^n-1). Nodes in the tree will be from 0 to 2^n-2
        self.nbits = nbits
        self.N = 2*nbits-1

        f = "{0:0"+str(self.nbits)+"b}"
        self.a = f.format(a)
        self.b = f.format(b)

        self.Tree = nx.DiGraph()
        for x in range(0, self.N+2*self.nbits):
            #print(x)
            self.Tree.add_node(x, weight = {'value': self.null_label})

        for x in range(0,int(self.N/2)):
            # Edges parent to sons
            self.Tree.add_edge(x, x*2+1, weight={"carry": self.null_label, "dir": "down", "son": "left", "epoch": 0})
            self.Tree.add_edge(x, x*2+2, weight={"carry": self.null_label, "dir": "down", "son": "right", "epoch": 0})
            self.Tree.add_edge(x*2+1, x, weight={"carry": self.null_label, "all1s": self.null_label, "dir": "up", 'son': 'left', "epoch": 0})
            self.Tree.add_edge(x*2+2, x, weight={"carry": self.null_label, "all1s": self.null_label, "dir": "up", 'son': 'right', "epoch": 0})
            # Edges son to parent

        # Add edges for the real IO nodes
        for x in range(self.nbits-1, self.N):
            self.Tree.add_edge((x-self.nbits+1)*2+self.N,     x, weight={"carry": self.null_label, "dir": "up", "in": 'a', "epoch": 0})
            self.Tree.add_edge((x-self.nbits+1)*2+self.N + 1, x, weight={"carry": self.null_label, "dir": "up", 'in': 'b', "epoch": 0})
    
        self.pos_y = [1.5*math.floor(math.log2(y+1)) - 3 for y in range(0, self.N)]
        self.pos_y = self.pos_y + [1.5*math.floor(math.log2(self.N+1)) - 3 for y in range(self.N, self.N+2*self.nbits)]
        
        self.pos_x = {x: x - self.N - self.nbits for x in range(self.N, self.N+2*self.nbits)}

        self.pos_x = {x: (self.pos_x[2*(x-self.nbits+1)+self.N]+self.pos_x[2*(x-self.nbits+1)+self.N+1])/2 for x in range(self.nbits-1, self.N)}|self.pos_x

        for level in range(int(math.log2(self.nbits))-1, -1, -1):
            self.pos_x = {x: (self.pos_x[x*2+1]+self.pos_x[x*2+2])/2 for x in range(2**level-1, 2**level-1+2**level)}|self.pos_x

        self.layout = { v: [self.pos_x[v], self.pos_y[v], 0] for v in range(0, self.N+2*self.nbits)}

        self.edges_labels_pos = {}
        for u, v, d in self.Tree.edges(data=True):
            if u in range(self.N, self.N+2*self.nbits): # IO nodes
                self.edges_labels_pos[(u,v)] = [(self.layout[u][0]+self.layout[v][0])/2, 
                                                (self.layout[u][1]+self.layout[v][1])/2, 
                                                0 ]
            else:
                if d['weight']['dir'] == 'down':
                    if d['weight']['son'] == 'left':
                        self.edges_labels_pos[(u,v)] = [(self.layout[u][0]+self.layout[v][0])/2 + -0.2, 
                                                        (self.layout[u][1]+self.layout[v][1])/2 + 0.15, 
                                                        0 ]
                    if d['weight']['son'] == 'right':
                        self.edges_labels_pos[(u,v)] = [(self.layout[u][0]+self.layout[v][0])/2 - 0.4, 
                                                        (self.layout[u][1]+self.layout[v][1])/2 + 0.15, 
                                                        0 ]
                else:
                    if d['weight']['son'] == 'left':
                        self.edges_labels_pos[(u,v)] = [(self.layout[u][0]+self.layout[v][0])/2 - 0.8, 
                                                        (self.layout[u][1]+self.layout[v][1])/2 + 0, 
                                                        0 ]
                    if d['weight']['son'] == 'right':
                        self.edges_labels_pos[(u,v)] = [(self.layout[u][0]+self.layout[v][0])/2 + 0.25, 
                                                        (self.layout[u][1]+self.layout[v][1])/2 + 0, 
                                                        0 ]


        #self.visiting_list = self.visit_order(0) + [x for x in range(self.N, self.N+2*self.nbits)]
        #print(f'{self.visiting_list}') if self.verbose else None
        self.epoch = 0

    def visit_order(self, node):
        sons = [ v for u, v, e in self.Tree.out_edges(node, data=True) if e['weight']['dir'] == 'down']
        assert(len(sons) == 2 or len(sons) == 0)
        if not sons:
            return [node]
        print(f'visiting {sons[0]} and {sons[1]}') if self.verbose else None
        list = self.visit_order(sons[1]) + [node] + self.visit_order(sons[0])
        return list

    def is_input(self, v):
        return v in range(self.N, self.N+2*self.nbits)
    
    def is_output(self, v):
        return v in range(self.nbits-1, self.N)

    def is_internal_node(self, v):
        return v in range(0, self.nbits-1)

    def get(self):
        return (self.Tree, self.layout, self.edges_labels_pos)

    def init_computation(self):
        print(self.a, self.b) if self.verbose else None

        for x in range(self.N, self.N+2*self.nbits, 2): # more efficient that checking if a node is input
            # nodes are 2 by 2 holding one bit of a and 1 bit of b
            self.Tree.nodes[x]['weight']['value'] = self.a[self.nbits-int((x-(self.N))/2)-1]
            self.Tree.nodes[x+1]['weight']['value'] = self.b[self.nbits-int((x-(self.N))/2)-1]

    def compute_iteration(self):
        some_updates = False
        for node in self.Tree:
            if self.is_input(node):
                for u, v, d in self.Tree.out_edges(node, data=True):
                    assert(u==node)
                    if self.Tree.nodes[node]['weight']['value'] != self.null_label:
                        some_updates = True
                        d['weight']['carry'] = self.Tree.nodes[node]['weight']['value']
                        d['weight']['epoch'] = self.epoch+1
                        self.Tree.nodes[node]['weight']['value'] = self.null_label
            if self.is_output(node):
                in_a = None
                in_b = None
                in_top = None
                out_up = None
                for (u,v,e) in self.Tree.in_edges(node, data=True):
                    if (e['weight']['dir'] == 'up'):
                        if e['weight']['in'] == "a":
                            in_a = e
                        if e['weight']['in'] == "b":
                            in_b = e
                    if e['weight']['dir'] == "down":
                        in_top = e

                for u, v, e in self.Tree.out_edges(node, data=True):
                    if e['weight']['dir'] == 'up':
                        out_up = e
                assert(in_a != None)
                assert(in_b != None)
                assert(in_top != None)
                assert(out_up != None)

                if in_a['weight']['carry'] != self.null_label and in_b['weight']['carry'] != self.null_label:
                    self.Tree.nodes[node]['weight']['value'] = xor_op(in_a['weight']['carry'], 
                                                                      in_b['weight']['carry'])
                    out_up['weight']['carry'] = and_op(in_a['weight']['carry'], in_b['weight']['carry'])
                    out_up['weight']['all1s'] = self.Tree.nodes[node]['weight']['value']
                    out_up['weight']['epoch'] = self.epoch+1
                    in_a['weight']['carry'] = self.null_label
                    in_b['weight']['carry'] = self.null_label
                    some_updates = True
                
                if in_top['weight']['carry'] != self.null_label and in_top['weight']['epoch'] <= self.epoch:
                    self.Tree.nodes[node]['weight']['value'] = xor_op(self.Tree.nodes[node]['weight']['value'], 
                                                                      in_top['weight']['carry'])
                    in_top['weight']['carry'] = self.null_label
                    some_updates = True

            if self.is_internal_node(node):
                in_left = None
                in_right = None
                in_top = None
                out_up = None
                out_left = None
                out_right = None
                for (u,v,e) in self.Tree.in_edges(node, data=True):
                    if (e['weight']['dir'] == 'up'):
                        if e['weight']['son'] == 'left':
                            in_left = e
                        if e['weight']['son'] == 'right':
                            in_right = e
                    if e['weight']['dir'] == 'down':
                        in_top = e

                for u, v, e in self.Tree.out_edges(node, data=True):
                    if e['weight']['dir'] == 'up':
                        out_up = e
                    if (e['weight']['dir'] == 'down'):
                        if e['weight']['son'] == 'left':
                            out_left = e
                        if e['weight']['son'] == 'right':
                            out_right = e
                assert(in_left != None)
                assert(in_right != None)
                assert(in_top != None if node!=0 else True)
                assert(out_up != None if node!=0 else True)
                assert(out_left != None)
                assert(out_right != None)

                if in_left['weight']['carry'] != self.null_label and in_right['weight']['carry'] != self.null_label and in_left['weight']['epoch'] <= self.epoch and in_right['weight']['epoch'] <= self.epoch:
                    self.Tree.nodes[node]['weight']['value'] = in_left['weight']['all1s']
                    out_right['weight']['carry'] = in_left['weight']['carry']
                    if node != 0:
                        out_up['weight']['carry'] = or_op(in_right['weight']['carry'], 
                                                          and_op(in_right['weight']['all1s'], in_left['weight']['carry']))
                        out_up['weight']['all1s'] = and_op(in_left['weight']['all1s'], in_right['weight']['all1s'])
                        out_up['weight']['epoch'] = self.epoch+1
                    
                    some_updates = True
                    in_left['weight']['carry'] = self.null_label
                    in_right['weight']['carry'] = self.null_label
                
                if node != 0:
                    if in_top['weight']['carry'] != self.null_label and in_top['weight']['epoch'] <= self.epoch:
                        if self.Tree.nodes[node]['weight']['value'] == '1':
                            out_right['weight']['carry'] = in_top['weight']['carry']
                            out_right['weight']['epoch'] = self.epoch + 1
                        else:
                            out_right['weight']['carry'] = '0'
                            out_right['weight']['epoch'] = self.epoch + 1
                        out_left['weight']['carry'] = in_top['weight']['carry']
                        out_left['weight']['epoch'] = self.epoch + 1
                        in_top['weight']['carry'] = self.null_label
                        some_updates = True
        self.epoch = self.epoch + 1
        return some_updates

    def read_output(self):
        s = ""
        n = 0
        i = 0
        for node in range(self.nbits-1, self.N):
            #print(f"self.Tree.nodes[node]['weight']['value'] = {self.Tree.nodes[node]['weight']['value']}")
            s = s + self.Tree.nodes[node]['weight']['value']
            if self.Tree.nodes[node]['weight']['value'] != self.null_label:
                n = n + (2**i)*int(self.Tree.nodes[node]['weight']['value'])
            i = i+1
        return s, n

In [None]:
# Testing networkw
nbits = 8
for i in range(0, 2**nbits):
	for j in range(0, 2**nbits):
		if (i+j < 2**nbits):
			network = tree_sum(nbits, i, j, verbose = False)
			network.init_computation()
			while network.compute_iteration():
				None
			if i+j != network.read_output()[1]:
				print(f'error in checking that {i} + {j} = {network.read_output()[1]} ({i+j})')
			assert( i+j == network.read_output()[1])˜

In [584]:
%%manim -qm ShowTree

config.frame_width = 19
class ShowTree(Scene):
    
    def construct(self):
        nbits = 8
        network = tree_sum(nbits, 84+32, 77+32, null_label='-')
        generic_run(self, network, null_lable='-', delay=0.2)
        print(network.read_output())


Iteration 0
Iteration 1


                                                                              

Iteration 2


                                                                              

Iteration 3


                                                                              

Iteration 4


                                                                              

Iteration 5


                                                                              

Iteration 6


                                                                               

Iteration 7


                                                                                       

('10000111', 225)
