# Visualization of Computing Networks

In [1]:
import networkx as nx
from manim import *
import math

config.media_width = "100%"
config.verbosity = "WARNING"

![ciao](media/tree.jpg)


## Operations performed by the nodes of the graphs
The values stored in the graph are actually strings, so far. This means that the numeric operations need to be performed on strings.

In [2]:
def xor_op(a, b):
    """Return '1' if the bitwise XOR of int(a) and int(b) is non-zero, else '0'."""
    if int(a) ^ int(b):
        return '1'
    return '0'

def and_op(a, b):
    """Return '1' if the bitwise AND of int(a) and int(b) is non-zero, else '0'."""
    if int(a) & int(b):
        return '1'
    return '0'

def or_op(a, b):
    """Return '1' if the bitwise OR of int(a) and int(b) is non-zero, else '0'."""
    if int(a) | int(b):
        return '1'
    return '0'

def add_op(a, b):
    """Add two decimal digits and return the units digit of the sum as a string.

    Both operands must convert to an int between 0 and 9 (checked with
    assertions, in the order a then b). The carry is not returned here.
    """
    first = int(a)
    assert 0 <= first <= 9
    second = int(b)
    assert 0 <= second <= 9
    # The sum is at most 18, so the units digit is simply the sum modulo 10.
    return str((first + second) % 10)

def carry_op(a, b):
    """Return the carry ('0' or '1') produced by adding two decimal digits.

    Both operands must convert to an int between 0 and 9 (checked with
    assertions, in the order a then b).
    """
    first = int(a)
    assert 0 <= first <= 9
    second = int(b)
    assert 0 <= second <= 9
    return '1' if first + second >= 10 else '0'

## Generic Run
This function takes a computing network and iterates on it to produce the visualization.

In [3]:

def _graph_from_network(Network, layout, null_label):
    """Build the manim ``Graph`` picturing the current state of ``Network``.

    Edges whose ``carry`` differs from ``null_label`` are drawn thick and white,
    the others thin and in their own ``color``. Vertices holding a value are
    filled white, empty ones black; every vertex is outlined in its ``color``.
    """
    return Graph.from_networkx(
        Network,
        layout=layout,
        labels={i: Network.nodes[i]['weight']['value'] for i in Network},
        edge_config={
            (u, v): {
                'stroke_color': (WHITE if d['weight']['carry'] != null_label else d['weight']['color']),
                'stroke_width': (6 if d['weight']['carry'] != null_label else 3),
            }
            for u, v, d in Network.edges(data=True)
        },
        vertex_config={
            u: {
                'fill_color': (WHITE if Network.nodes[u]['weight']['value'] != null_label else BLACK),
                'stroke_width': 3,
                'stroke_color': Network.nodes[u]['weight']['color'],
            }
            for u in Network.nodes
        },
    )


def _active_edge_texts(network, Network, active_edges):
    """Return one red ``Text`` per active edge, showing the value it carries.

    ``network.string_for_edge`` is used when it works; if it fails (for
    instance because the network does not define it) the raw ``carry``
    attribute of the edge is shown instead.
    """
    try:
        return {(u, v): Text(str(network.string_for_edge(u, v)), color=RED, font_size=32)
                for u, v in active_edges}
    except Exception:
        # Deliberate best-effort fallback; unlike a bare ``except`` this does
        # not swallow KeyboardInterrupt / SystemExit.
        return {(u, v): Text(str(Network[u][v]['weight']['carry']), color=RED, font_size=32)
                for u, v in active_edges}


def _draw_additional_items(network, selfi):
    """Let the network add its own decorations to the scene (best effort)."""
    try:
        network.additional_items_to_draw(selfi)
    except Exception as e:
        print("Nothing to add:", e)


def generic_run(selfi, network, **kwargs):
    """Animate a computing network, one frame per computation step.

    Parameters
    ----------
    selfi : the manim scene to draw on (``play``, ``add`` and ``wait`` are
        called on it, and the drawn graph is stored in ``selfi.G``).
    network : object providing ``get()`` (returning the networkx graph and its
        layout), ``init_computation()`` and ``compute_iteration()``;
        ``additional_items_to_draw(scene)`` and ``string_for_edge(u, v)`` are
        used when available.
    null_label : keyword, default ``'-'``; marker of an empty node/edge.
    delay : keyword, default ``1``; pause between steps (animations run for
        ``1.5 * delay``).

    The loop stops after the first step for which ``compute_iteration()``
    returns a falsy value.
    """
    null_label = kwargs.get('null_label', '-')
    delay = kwargs.get('delay', 1)
    (Network, layout) = network.get()

    network.init_computation()
    finished = False
    step = 0
    while not finished:
        print(f'Iteration {step}')
        if step == 0:
            # First frame: draw the graph in its initial state plus the step counter.
            selfi.G = _graph_from_network(Network, layout, null_label)
            selfi.play(Create(selfi.G))
            counter = Text(str(step)).to_edge(UL)
            selfi.add(counter)
            _draw_additional_items(network, selfi)
        else:
            selfi.wait(delay)

            next_graph = _graph_from_network(Network, layout, null_label)
            # Edges currently carrying a value: their label travels from source to target.
            active_edges = [(u, v) for u, v, d in Network.edges(data=True)
                            if d['weight']['carry'] != null_label]
            new_texts_from = _active_edge_texts(network, Network, active_edges)
            new_texts_to = _active_edge_texts(network, Network, active_edges)

            _draw_additional_items(network, selfi)

            selfi.play(Transform(selfi.G, next_graph),
                       *[Transform(new_texts_from[(u, v)].move_to(selfi.G[u]),
                                   new_texts_to[(u, v)].move_to(selfi.G[v]))
                         for u, v in active_edges],
                       path_arc=70 * DEGREES, run_time=1.5*delay)
            selfi.play(Transform(counter, Text(str(step)).to_edge(UL)),
                       *[FadeOut(new_texts_from[(u, v)]) for u, v in active_edges],
                       run_time=1.5*delay)
        finished = not network.compute_iteration()
        step = step + 1
    selfi.wait(delay)


## Linear binary sum
This is a basic example of a computing network computing the sum of two binary numbers.
The linearity of the operation is apparent when comparing it with the later tree-based binary sum implementation.

In [4]:
class linear_sum:
    """Ripple-carry binary adder laid out as a linear computing network.

    Bit k (k = 0 is the least significant) uses three consecutive nodes:
    ``3k`` holds the bit of ``a``, ``3k+2`` the bit of ``b`` and ``3k+1`` is
    the adder node that ends up holding the result bit. Each adder node
    receives its two input bits and, except for node 1, the carry produced
    by the adder node of the previous bit.

    Node values and edge ``carry`` attributes are strings; ``null_label``
    marks an empty node/edge.
    """

    def is_carry(self, v):
        """Return True if ``v`` is an adder (carry/output) node, i.e. v = 3k+1."""
        return (v-1)%3==0

    def is_input(self, v):
        """Return True if ``v`` is an input node (holds a bit of a or b)."""
        return not self.is_carry(v)

    def __init__(self, nbits, a, b, **kwargs):
        """Build the network adding the integers ``a`` and ``b`` on ``nbits`` bits.

        Extra keyword arguments are accepted and ignored.
        """
        # nbits is the number of bits of the operation
        self.nbits = nbits

        f = "{0:0"+str(self.nbits)+"b}"
        self.a = f.format(a)
        self.b = f.format(b)

        self.null_label = '-'
        self.Array = nx.DiGraph()

        # Example with 8 bits (i = input node, c = adder/carry node):
        #        23 22 21 20 19 18 17 16 15 14 13 12 11 10 9  8  7  6  5  4  3  2  1  0
        #        i  c  i  i  c  i  i  c  i  i  c  i  i  c  i  i  c  i  i  c  i  i  c  i
        for x in range(0, 3*self.nbits):
            # Adder nodes are outlined in white, input nodes in green.
            # (The former test `x-1%3==0` parsed as `x-(1%3)==0`, which is
            # true for node 1 only.)
            self.Array.add_node(x, weight = {'value': self.null_label,
                                             'color': WHITE if self.is_carry(x) else GREEN})

        # Edges bringing the two input bits to the adder node of each bit
        for x in range(1, 3*self.nbits, 3):
            self.Array.add_edge(x-1, x, weight={"carry": self.null_label, 'color': GREEN})
            self.Array.add_edge(x+1, x, weight={"carry": self.null_label, 'color': GREEN})

        # Edges propagating the carry to the adder node of the next bit
        for x in range(1, 3*(self.nbits-1), 3):
            self.Array.add_edge(x, x+3, weight={"carry": self.null_label, 'color': GREEN})

        # Nodes are placed right to left; adder nodes on the top row (y = 0),
        # the inputs of a on y = -1 and the inputs of b on y = -2.
        self.layout = { x: [3*nbits/2-x, 0 if (x-1)%3==0 else (-1 if x%3==0 else -2), 0] for x in range(0, 3*self.nbits)}

    def additional_items_to_draw(self, selfi):
        """Frame the two input rows (yellow) and the output row (blue) on the first step."""
        if self.time == 0:
            input1 = Group(*[selfi.G[x] for x in range(0, 3*self.nbits, 3)])
            input2 = Group(*[selfi.G[x] for x in range(2, 3*self.nbits, 3)])
            output = Group(*[selfi.G[x] for x in range(1, 3*self.nbits, 3)])
            selfi.add(  SurroundingRectangle(input1, color=YELLOW)
                        , SurroundingRectangle(input2, color=YELLOW)
                        , SurroundingRectangle(output, color=BLUE)
            )

    def get(self):
        """Return the pair (networkx graph, layout dictionary)."""
        return (self.Array, self.layout)

    def init_computation(self):
        """Reset the step counter and load the bits of a and b into the input nodes."""
        self.time = 0
        for i in range(0, 3*self.nbits, 3):
            # self.a / self.b are written most significant bit first.
            position = self.nbits - 1 - i//3
            self.Array.nodes[i]['weight']['value'] = self.a[position]
            self.Array.nodes[i+2]['weight']['value'] = self.b[position]

    def compute_iteration(self):
        """Perform one synchronous step; return True if anything changed.

        On the first step the input nodes copy their value onto their outgoing
        edge. An adder node fires as soon as all its incoming edges carry a
        value: it consumes them, stores the sum bit and emits the carry.
        """
        some_updates = False
        null = self.null_label

        # Nodes are visited from the most to the least significant one, so a
        # carry emitted during this step is consumed only at the next step.
        for node in range(3*self.nbits-1, -1, -1):
            if self.is_input(node):
                if self.time == 0:
                    value = self.Array.nodes[node]['weight']['value']
                    if value != null:
                        for _, _, d in self.Array.out_edges(node, data=True):
                            some_updates = True
                            d['weight']['carry'] = value
                continue

            in_edges = [d for _, _, d in self.Array.in_edges(node, data=True)]
            # The computation is ready to start only when every operand has arrived.
            if not all(d['weight']['carry'] != null for d in in_edges):
                continue

            some_updates = True
            # Node 1 has no carry-in: two operands; every other adder has three.
            assert len(in_edges) == (2 if node == 1 else 3)
            operands = [d['weight']['carry'] for d in in_edges]
            for d in in_edges:
                d['weight']['carry'] = null

            if node == 1:
                a, b = operands
                sum_bit = xor_op(a, b)
                carry_out = and_op(a, b)
            else:
                a, b, c = operands
                sum_bit = xor_op(xor_op(a, b), c)
                # Majority of the three operands.
                carry_out = or_op(and_op(a, b), or_op(and_op(a, c), and_op(b, c)))

            self.Array.nodes[node]['weight']['value'] = sum_bit
            for _, _, d in self.Array.out_edges(node, data=True):
                d['weight']['carry'] = carry_out
        self.time = self.time + 1
        return some_updates

    def read_output(self):
        """Return (string, number) read from the adder nodes.

        The string lists the node values starting from the least significant
        bit; nodes still holding ``null_label`` count as 0 in the number.
        """
        bits = [self.Array.nodes[node]['weight']['value']
                for node in range(1, 3*(self.nbits), 3)]
        n = sum(2**i*int(bit) for i, bit in enumerate(bits) if bit != self.null_label)
        return "".join(bits), n

## Testing linear binary sum
This cell takes the computing network above and tests all possible inputs, quite naively, without visualization.

In [5]:
# Exhaustive check of the linear adder on every pair whose sum fits in nbits bits
nbits = 8
for i in range(0, 2**nbits):
    for j in range(0, 2**nbits):
        if i + j >= 2**nbits:
            continue  # the sum would overflow the available bits
        network = linear_sum(nbits, i, j)
        network.init_computation()
        # Run until a step produces no update.
        while network.compute_iteration():
            pass
        assert i + j == network.read_output()[1]

## Visualize the Linear Binary Addition
The following code visualizes the linear sum network.

In [6]:
%%manim -qm ShowLinear

config.frame_width = 25


class ShowLinear(Scene):
    """Animate the linear adder computing 116 + 109 on 8 bits."""

    def construct(self):
        nbits = 8
        network = linear_sum(nbits, 32+84, 77+32)
        # `delay` is an option of generic_run; linear_sum silently ignores it.
        generic_run(self, network, delay=1)
        print(network.read_output())

Iteration 0


                                                                                                                  

Iteration 1


                                                                                                                           

Iteration 2


                                                                                                                           

Iteration 3


                                                                                                                           

Iteration 4


                                                                                                                            

Iteration 5


                                                                                                                            

Iteration 6


                                                                                                                            

Iteration 7


                                                                                                                            

Iteration 8


                                                                                                                            

Iteration 9


                                                                                                                      

('10000111', 225)


## Base tree for binary operations

This class implements a binary tree with an additional layer below the leaves for carrying the input values.
At the start of the actual computation, the input values are placed on the edges so that the tree can start processing them to produce the results, which are stored at the leaves of the tree.

In [7]:
class base_tree_sum:
    """Binary tree (heap numbering) with an extra layer of input nodes.

    Node numbering, with ``N = 2*nbits - 1`` tree nodes:
      * ``0 .. nbits-2``      internal nodes (node ``x`` has sons ``2x+1`` and ``2x+2``);
      * ``nbits-1 .. N-1``    leaves of the tree (the output nodes);
      * ``N .. N+2*nbits-1``  input nodes, two per leaf (operand ``a`` first, then ``b``).

    The layout code computes the horizontal positions level by level using
    ``int(log2(nbits))``, so it expects ``nbits`` to be a power of two.
    """

    def __init__(self, nbits, **kwargs):
        """Create nodes, edges and layout.

        Keyword options: ``null_label`` (default ``'-'``) marks an empty
        node/edge; ``verbose`` (default ``False``) is stored for subclasses.
        """
        self.null_label = kwargs.get('null_label', '-')
        self.verbose = kwargs.get('verbose', False)

        # nbits: number of digits of the operands; N: number of nodes of the tree proper
        self.nbits = nbits
        self.N = 2*nbits-1
        total_nodes = self.N + 2*self.nbits

        self.Tree = nx.DiGraph()
        for node in range(total_nodes):
            self.Tree.add_node(node, weight={'value': self.null_label, 'color': GREEN})

        # Every internal node is linked to its two sons in both directions.
        for parent in range(self.N // 2):
            left, right = 2*parent + 1, 2*parent + 2
            self.Tree.add_edge(parent, left, weight={"carry": self.null_label, "dir": "down", "son": "left", "epoch": 0, 'color': GREEN})
            self.Tree.add_edge(parent, right, weight={"carry": self.null_label, "dir": "down", "son": "right", "epoch": 0, 'color': GREEN})
            self.Tree.add_edge(left, parent, weight={"carry": self.null_label, "all1s": self.null_label, "dir": "up", 'son': 'left', "epoch": 0, 'color': GREEN})
            self.Tree.add_edge(right, parent, weight={"carry": self.null_label, "all1s": self.null_label, "dir": "up", 'son': 'right', "epoch": 0, 'color': GREEN})

        # Each leaf receives one digit of a and one digit of b from its two input nodes.
        for offset, leaf in enumerate(range(self.nbits-1, self.N)):
            a_node = self.N + 2*offset
            self.Tree.add_edge(a_node, leaf, weight={"carry": self.null_label, "dir": "up", "in": 'a', "epoch": 0, 'color': GREEN})
            self.Tree.add_edge(a_node + 1, leaf, weight={"carry": self.null_label, "dir": "up", 'in': 'b', "epoch": 0, 'color': GREEN})

        # Vertical position: one row per tree level, then two staggered rows for the inputs.
        tree_rows = [1.5*math.floor(math.log2(y+1)) - 3 for y in range(0, self.N)]
        input_rows = [1.5*math.floor(math.log2(self.N+1)) - (3 if y%2==0 else 2)
                      for y in range(self.N, total_nodes)]
        self.pos_y = tree_rows + input_rows

        # Horizontal position: inputs are spread out, every other node sits
        # halfway between its two children.
        input_x = {x: x - self.N - self.nbits for x in range(self.N, total_nodes)}
        leaf_x = {}
        for leaf in range(self.nbits-1, self.N):
            a_node = 2*(leaf - self.nbits + 1) + self.N
            leaf_x[leaf] = (input_x[a_node] + input_x[a_node + 1])/2
        self.pos_x = {**leaf_x, **input_x}

        for level in range(int(math.log2(self.nbits))-1, -1, -1):
            first = 2**level - 1
            level_x = {x: (self.pos_x[x*2+1] + self.pos_x[x*2+2])/2
                       for x in range(first, first + 2**level)}
            self.pos_x = {**level_x, **self.pos_x}

        self.layout = {v: [self.pos_x[v], self.pos_y[v], 0] for v in range(0, total_nodes)}

        self.epoch = 0

    def string_for_edge(self, u, v):
        """Return the text to display for what edge (u, v) currently carries.

        Edges with an ``all1s`` attribute show ``carry`` followed by ``all1s``
        when both are set; otherwise the ``carry`` (or the null label) is shown.
        """
        weight = self.Tree.get_edge_data(u, v)['weight']
        if 'all1s' not in weight:
            return weight['carry']
        if weight['carry'] == self.null_label:
            return self.null_label
        if weight['all1s'] == self.null_label:
            return weight['carry']
        return weight['carry'] + weight['all1s']

    def additional_items_to_draw(self, selfi):
        """During epoch 0, frame the two input rows (yellow) and the output row (blue)."""
        if self.epoch != 0:
            return
        last_input = self.N + 2*self.nbits
        a_inputs = [selfi.G[x] for x in range(self.N, last_input, 2)]
        b_inputs = [selfi.G[x] for x in range(self.N+1, last_input, 2)]
        outputs = [selfi.G[x] for x in range(self.N-self.nbits, self.N)]
        selfi.add(SurroundingRectangle(Group(*a_inputs), color=YELLOW))
        selfi.add(SurroundingRectangle(Group(*b_inputs), color=YELLOW))
        selfi.add(SurroundingRectangle(Group(*outputs), color=BLUE))

    def is_input(self, v):
        """Return True if ``v`` is one of the input nodes below the leaves."""
        input_nodes = range(self.N, self.N+2*self.nbits)
        return v in input_nodes

    def is_output(self, v):
        """Return True if ``v`` is a leaf of the tree (an output node)."""
        output_nodes = range(self.nbits-1, self.N)
        return v in output_nodes

    def is_internal_node(self, v):
        """Return True if ``v`` is an internal node of the tree."""
        internal_nodes = range(0, self.nbits-1)
        return v in internal_nodes

    def get(self):
        """Return the pair (networkx graph, layout dictionary)."""
        return (self.Tree, self.layout)



## Generic tree for sums
The algorithm to perform additions is generic. Given a number in base $b \ge 2$, adding two digits will always result in a carry that is $0$ or $1$. This means that the internal nodes of the tree always perform the same computations. At the leaves, the computations need to be adapted to the base. If the base is 2, the logical operations (`xor` and `and`) are needed, but in the case of base 10 the addition is between digits from 0 to 9 (this could be computed with a lookup table, in case we do not want to use the arithmetic operations, so as to highlight the _procedural_ (a.k.a. dumb) nature of the computation).
The key observation that makes this algorithm work is that a carry inserted before the least significant digit can affect the carry after the most significant one only if all the digits in the number are $b-1$ (where $b$ is the base). An internal node, then, stores "1" if the "left" (least significant) subtree is made only of $b-1$ digits. In this case, a carry coming from above is pushed to both sons; otherwise it is pushed only to the left, where it will be _absorbed_ by the computation and will not propagate to the right.

In [8]:
class tree_sum(base_tree_sum):
    """Addition of two numbers on the tree, generic in the numerical base.

    ``sum_op(x, y)`` must return the digit of x + y (without carry) and
    ``carry_op(x, y)`` the corresponding carry ('0' or '1'); both work on
    single-character strings. Subclasses are expected to provide
    ``self.a`` and ``self.b`` (digit strings, most significant digit first)
    before ``init_computation`` is called.
    """

    def __init__(self, nbits, numerical_base, sum_op, carry_op, **kwargs):
        super().__init__(nbits, **kwargs)
        self.sum_op = sum_op
        self.carry_op = carry_op
        # Largest digit of the base (b-1), as a string: a run of such digits
        # lets an incoming carry travel through.
        self.max_digit = str(numerical_base-1)

    def init_computation(self):
        """Load the digits of a and b into the input nodes (least significant first)."""
        if self.verbose:
            print(self.a, self.b)

        # Input nodes come in pairs: one digit of a followed by one digit of b.
        for x in range(self.N, self.N+2*self.nbits, 2):
            position = self.nbits - (x - self.N)//2 - 1
            self.Tree.nodes[x]['weight']['value'] = self.a[position]
            self.Tree.nodes[x+1]['weight']['value'] = self.b[position]

    def compute_iteration(self):
        """Perform one synchronous step (epoch); return True if anything changed.

        Values written on an edge are stamped with ``epoch + 1`` so that the
        receiving node does not consume them before the next step.
        """
        some_updates = False
        for node in self.Tree:
            if self.is_input(node):
                # Inputs only act once, at the very first epoch.
                changed = self.epoch == 0 and self._emit_input(node)
            elif self.is_output(node):
                changed = self._step_leaf(node)
            elif self.is_internal_node(node):
                changed = self._step_internal(node)
            else:
                changed = False
            if changed:
                some_updates = True
        self.epoch = self.epoch + 1
        return some_updates

    def _emit_input(self, node):
        """Copy the value of an input node onto its outgoing edge(s)."""
        value = self.Tree.nodes[node]['weight']['value']
        if value == self.null_label:
            return False
        changed = False
        for _, _, d in self.Tree.out_edges(node, data=True):
            d['weight']['carry'] = value
            d['weight']['epoch'] = self.epoch+1
            changed = True
        return changed

    def _step_leaf(self, node):
        """Process a leaf: add its two digits, then absorb a carry coming from above."""
        null = self.null_label
        in_a = in_b = in_top = out_up = None
        for _, _, e in self.Tree.in_edges(node, data=True):
            if e['weight']['dir'] == 'up':
                if e['weight']['in'] == "a":
                    in_a = e
                if e['weight']['in'] == "b":
                    in_b = e
            if e['weight']['dir'] == "down":
                in_top = e
        for _, _, e in self.Tree.out_edges(node, data=True):
            if e['weight']['dir'] == 'up':
                out_up = e
        assert in_a is not None
        assert in_b is not None
        assert in_top is not None
        assert out_up is not None

        changed = False
        digit_a = in_a['weight']['carry']
        digit_b = in_b['weight']['carry']
        if digit_a != null and digit_b != null:
            digit = self.sum_op(digit_a, digit_b)
            self.Tree.nodes[node]['weight']['value'] = digit
            out_up['weight']['carry'] = self.carry_op(digit_a, digit_b)
            # Tell the parent whether a carry entering this digit would pass through.
            out_up['weight']['all1s'] = '1' if digit == self.max_digit else '0'
            out_up['weight']['epoch'] = self.epoch+1
            in_a['weight']['carry'] = null
            in_b['weight']['carry'] = null
            changed = True

        if in_top['weight']['carry'] != null and in_top['weight']['epoch'] <= self.epoch:
            self.Tree.nodes[node]['weight']['value'] = self.sum_op(
                self.Tree.nodes[node]['weight']['value'], in_top['weight']['carry'])
            in_top['weight']['carry'] = null
            changed = True
        return changed

    def _step_internal(self, node):
        """Process an internal node: combine its sons' reports, route carries downwards."""
        null = self.null_label
        in_left = in_right = in_top = None
        out_up = out_left = out_right = None
        for _, _, e in self.Tree.in_edges(node, data=True):
            if e['weight']['dir'] == 'up':
                if e['weight']['son'] == 'left':
                    in_left = e
                if e['weight']['son'] == 'right':
                    in_right = e
            if e['weight']['dir'] == 'down':
                in_top = e
        for _, _, e in self.Tree.out_edges(node, data=True):
            if e['weight']['dir'] == 'up':
                out_up = e
            if e['weight']['dir'] == 'down':
                if e['weight']['son'] == 'left':
                    out_left = e
                if e['weight']['son'] == 'right':
                    out_right = e
        is_root = node == 0
        assert in_left is not None
        assert in_right is not None
        # The root has no parent, hence neither an incoming "down" nor an outgoing "up" edge.
        assert is_root or in_top is not None
        assert is_root or out_up is not None
        assert out_left is not None
        assert out_right is not None

        changed = False
        sons_ready = (in_left['weight']['carry'] != null
                      and in_right['weight']['carry'] != null
                      and in_left['weight']['epoch'] <= self.epoch
                      and in_right['weight']['epoch'] <= self.epoch)
        if sons_ready:
            # Remember whether the left (least significant) subtree is all max digits.
            self.Tree.nodes[node]['weight']['value'] = in_left['weight']['all1s']
            # The carry leaving the left subtree enters the right subtree.
            out_right['weight']['carry'] = in_left['weight']['carry']
            out_right['weight']['epoch'] = self.epoch+1
            if not is_root:
                out_up['weight']['carry'] = or_op(in_right['weight']['carry'],
                                                  and_op(in_right['weight']['all1s'], in_left['weight']['carry']))
                out_up['weight']['all1s'] = and_op(in_left['weight']['all1s'], in_right['weight']['all1s'])
                out_up['weight']['epoch'] = self.epoch+1
            changed = True
            in_left['weight']['carry'] = null
            in_right['weight']['carry'] = null

        if not is_root:
            if in_top['weight']['carry'] != null and in_top['weight']['epoch'] <= self.epoch:
                # A carry from above reaches the right son only when the left
                # subtree is all max digits; otherwise a '0' is sent.
                if self.Tree.nodes[node]['weight']['value'] == '1':
                    out_right['weight']['carry'] = in_top['weight']['carry']
                else:
                    out_right['weight']['carry'] = '0'
                out_right['weight']['epoch'] = self.epoch + 1
                out_left['weight']['carry'] = in_top['weight']['carry']
                out_left['weight']['epoch'] = self.epoch + 1
                in_top['weight']['carry'] = null
                changed = True
        return changed


## Binary Sum Tree
What is specific here is the fact that the operations are `xor` and `and`, that $b-1=1$, and that the input is a binary string.

In [9]:
class bin_tree_sum(tree_sum):
    """Tree adder specialised to base 2: digit sum is ``xor``, carry is ``and``."""

    def __init__(self, nbits, **kwargs):
        super().__init__(nbits, 2, xor_op, and_op, **kwargs)

    def set_input(self, a, b):
        """Store a and b as binary strings zero-padded to ``nbits`` characters."""
        spec = "0" + str(self.nbits) + "b"
        self.a = format(a, spec)
        self.b = format(b, spec)

    def read_output(self):
        """Return (string, number) read from the leaves.

        The string lists the leaf values starting from the least significant
        bit; leaves still holding the null label count as 0 in the number.
        """
        bits = [self.Tree.nodes[leaf]['weight']['value']
                for leaf in range(self.nbits-1, self.N)]
        value = sum((2**i)*int(bit) for i, bit in enumerate(bits)
                    if bit != self.null_label)
        return "".join(bits), value

## Decimal Sum Tree
What is specific here is the fact that the operation is `+` and the carry is '0' or '1' depending on whether the sum is smaller than 10 or not, that $b-1=9$, and that the input is just the number padded with '0' on the left, if needed.

In [10]:
class dec_tree_sum(tree_sum):
    """Tree adder specialised to base 10, using ``add_op`` and ``carry_op``."""

    def __init__(self, nbits, **kwargs):
        super().__init__(nbits, 10, add_op, carry_op, **kwargs)

    def set_input(self, a, b):
        """Store a and b as decimal strings zero-padded to ``nbits`` characters."""
        spec = "0" + str(self.nbits)
        self.a = format(a, spec)
        self.b = format(b, spec)

    def read_output(self):
        """Return (string, number) read from the leaves.

        The string lists the leaf values starting from the least significant
        digit; leaves still holding the null label count as 0 in the number.
        """
        digits = [self.Tree.nodes[leaf]['weight']['value']
                  for leaf in range(self.nbits-1, self.N)]
        value = sum((10**i)*int(digit) for i, digit in enumerate(digits)
                    if digit != self.null_label)
        return "".join(digits), value

## Testing the binary sum

In [13]:
# Sampled check of the binary tree adder: operands on a coarse grid, sums that fit in nbits bits
nbits = 16
for i in range(0, 2**nbits, 323):
    for j in range(0, 2**nbits, 252):
        if i + j >= 2**nbits:
            continue  # the sum would overflow the available bits
        print(f'{i}+{j}', end='\r')
        network = bin_tree_sum(nbits, verbose = False)
        network.set_input(i,j)
        network.init_computation()
        # Run until a step produces no update.
        while network.compute_iteration():
            pass
        strres, numres = network.read_output()
        if i+j != numres:
            print(f'error in checking that {i} + {j} = {numres} ({i+j}) [{strres}]')
        assert i+j == numres

65246+25280

## Visualizing the Binary Sum on a tree

In [14]:
%%manim -qm ShowBinaryTreeSum

config.frame_width = 19
class ShowBinaryTreeSum(Scene):
    """Animate the binary tree adder computing 120 + 8 on 8 bits."""

    def construct(self):
        nbits = 8
        network = bin_tree_sum(nbits, null_label='-')
        network.set_input(64+32+8+16, 8)
        # The keyword was misspelled `null_lable`, so generic_run never saw it.
        generic_run(self, network, null_label='-', delay=1)
        print(network.read_output())


Iteration 0


                                                                                                        

Iteration 1


                                                                                                                 

Iteration 2


                                                                                                                 

Iteration 3


                                                                                                                 

Iteration 4


                                                                                                                  

Iteration 5


                                                                                                                  

Iteration 6


                                                                                                                  

Iteration 7


                                                                                                                  

Iteration 8


                                                                                                            

('00000001', 128)


## Testing the decimal sum

In [33]:
import random

# Randomized test of the decimal tree adder (dec_tree_sum)
nbits = 8
ntests = 60000
# Fixed seed: a failing pair can be reproduced by simply re-running the cell.
SEED = 42
rng = random.Random(SEED)
actual_tests = 0
for t in range(0, ntests):
    # Operands have at most nbits digits (randint includes both bounds).
    a = rng.randint(0, 10**nbits - 1)
    b = rng.randint(0, 10**nbits - 1)
    if a + b >= 10**nbits:
        continue  # the sum would not fit in nbits digits
    actual_tests += 1
    print(f'test {actual_tests}/{ntests}: {a}+{b}', end='\r')
    network = dec_tree_sum(nbits, verbose = False)
    network.set_input(a,b)
    network.init_computation()
    # Run until a step produces no update.
    while network.compute_iteration():
        pass
    strres, numres = network.read_output()
    if a+b != numres:
        print(f'error in checking that {a} + {b} = {numres} ({a+b}) [{strres}]')
    assert a+b == numres

test 30219/60000: 764840+9030695270

## Visualizing the Sum of decimal numbers on a tree

In [34]:
%%manim -qm ShowDecTree

config.frame_width = 19
class ShowDecTree(Scene):
    """Animate the decimal tree adder while it sums two fixed numbers."""

    def construct(self):
        """Build an 8-digit `dec_tree_sum` network, animate it and print the result."""
        nbits = 8
        network = dec_tree_sum(nbits, null_label='-')
        a = 3765708
        b = 45634234
        network.set_input(a, b)
        # generic_run looks up the keyword 'null_label'; the previous spelling
        # 'null_lable' was silently ignored in favour of the default.
        generic_run(self, network, null_label='-', delay=1)
        print(f'result: {network.read_output()}, should be {a+b}')


Iteration 0


                                                                                                                  

Iteration 1


                                                                                                                           

Iteration 2


                                                                                                                           

Iteration 3


                                                                                                                           

Iteration 4


                                                                                                                            

Iteration 5


                                                                                                                            

Iteration 6


                                                                                                                            

Iteration 7


                                                                                                                            

Iteration 8


                                                                                                                      

result: ('24999394', 49399942), should be 49399942


## Binary Comparator in Log Time

In [35]:
class bin_comparator(base_tree_sum):
    """Tree network that compares two non-negative integers ``a`` and ``b``.

    Results travel up the tree as two-character strings:

    * ``'10'`` -- a > b
    * ``'01'`` -- a < b
    * ``'00'`` -- a == b

    Each node for which ``is_output`` is true compares one bit of ``a`` with
    the matching bit of ``b``.  Each internal node forwards the result of its
    right child unless that result is ``'00'``, in which case the result of
    the left child is forwarded.  The final answer is the value of node 0.
    """

    def __init__(self, nbits, **kwargs):
        # Zero-argument super() keeps working when this cell is re-executed and
        # the name `bin_comparator` gets rebound to a new class object.
        super().__init__(nbits, **kwargs)

    def set_input(self, a, b):
        """Store ``a`` and ``b`` as binary strings zero-padded to ``nbits`` characters.

        Padding only sets a minimum width: a value that needs more than
        ``nbits`` bits yields a longer string and is not truncated here.
        """
        bit_format = f'0{self.nbits}b'
        self.a = format(a, bit_format)
        self.b = format(b, bit_format)

    def init_computation(self):
        """Copy the stored bits of ``a`` and ``b`` into the input nodes.

        Input nodes are addressed directly as ``N, N+1, ..., N+2*nbits-1``
        (cheaper than testing every node with ``is_input``).  They come in
        pairs: node ``N+2k`` receives a bit of ``a`` and node ``N+2k+1`` the
        matching bit of ``b``, with ``k = 0`` mapped to the last character of
        the binary strings.
        """
        if self.verbose:
            print(self.a, self.b)

        for k in range(self.nbits):
            node = self.N + 2 * k
            pos = self.nbits - k - 1
            self.Tree.nodes[node]['weight']['value'] = self.a[pos]
            self.Tree.nodes[node + 1]['weight']['value'] = self.b[pos]

    def compute_iteration(self):
        """Advance the whole network by one epoch.

        Returns:
            True if at least one node consumed or produced a value during this
            iteration, False once the network has become quiescent.
        """
        some_updates = False
        for node in self.Tree:
            # Three independent checks (deliberately not an elif chain): each
            # predicate is evaluated for every node.
            if self.is_input(node) and self.epoch < 1:
                some_updates = self._emit_input(node) or some_updates
            if self.is_output(node):
                some_updates = self._compare_bits(node) or some_updates
            if self.is_internal_node(node):
                some_updates = self._merge_verdicts(node) or some_updates

        self.epoch = self.epoch + 1
        return some_updates

    def _emit_input(self, node):
        """Place the value of input ``node`` on its outgoing edges; return True if anything was sent."""
        sent = False
        for u, v, d in self.Tree.out_edges(node, data=True):
            assert u == node
            value = self.Tree.nodes[node]['weight']['value']
            if value != self.null_label:
                sent = True
                d['weight']['carry'] = value
                d['weight']['epoch'] = self.epoch + 1
        return sent

    def _compare_bits(self, node):
        """Compare the ``a`` and ``b`` bits arriving at ``node``; return True if a result was produced."""
        in_a = None
        in_b = None
        in_top = None
        out_up = None
        for u, v, e in self.Tree.in_edges(node, data=True):
            if e['weight']['dir'] == 'up':
                if e['weight']['in'] == "a":
                    in_a = e
                if e['weight']['in'] == "b":
                    in_b = e
            if e['weight']['dir'] == "down":
                in_top = e

        for u, v, e in self.Tree.out_edges(node, data=True):
            if e['weight']['dir'] == 'up':
                out_up = e

        # Structural sanity checks on the tree wiring
        assert in_a is not None
        assert in_b is not None
        assert in_top is not None
        assert out_up is not None

        # Unlike internal nodes, the edge epoch is not checked here.
        if in_a['weight']['carry'] != self.null_label and in_b['weight']['carry'] != self.null_label:
            bit_a = int(in_a['weight']['carry'])
            bit_b = int(in_b['weight']['carry'])
            if bit_a > bit_b:
                verdict = '10'
            elif bit_a < bit_b:
                verdict = '01'
            else:
                verdict = '00'
            self.Tree.nodes[node]['weight']['value'] = verdict
            out_up['weight']['carry'] = self.Tree.nodes[node]['weight']['value']
            out_up['weight']['epoch'] = self.epoch + 1
            # The two input bits have been consumed
            in_a['weight']['carry'] = self.null_label
            in_b['weight']['carry'] = self.null_label
            return True
        return False

    def _merge_verdicts(self, node):
        """Combine the results of the two children of internal ``node``; return True if one was produced."""
        # The edge lookup is the "standard" one shared by tree networks; the
        # comparator itself never sends anything down the tree.
        in_left = None
        in_right = None
        in_top = None
        out_up = None
        out_left = None
        out_right = None
        for u, v, e in self.Tree.in_edges(node, data=True):
            if e['weight']['dir'] == 'up':
                if e['weight']['son'] == 'left':
                    in_left = e
                if e['weight']['son'] == 'right':
                    in_right = e
            if e['weight']['dir'] == 'down':
                in_top = e

        for u, v, e in self.Tree.out_edges(node, data=True):
            if e['weight']['dir'] == 'up':
                out_up = e
            if e['weight']['dir'] == 'down':
                if e['weight']['son'] == 'left':
                    out_left = e
                if e['weight']['son'] == 'right':
                    out_right = e

        # Node 0 is the root: it has no edges towards a parent
        assert in_left is not None
        assert in_right is not None
        assert node == 0 or in_top is not None
        assert node == 0 or out_up is not None
        assert out_left is not None
        assert out_right is not None

        # Only consume values that were produced in a previous epoch
        if (in_left['weight']['carry'] != self.null_label
                and in_right['weight']['carry'] != self.null_label
                and in_left['weight']['epoch'] <= self.epoch
                and in_right['weight']['epoch'] <= self.epoch):
            # The right child decides unless it reports equality ('00')
            if in_right['weight']['carry'] != '00':
                self.Tree.nodes[node]['weight']['value'] = in_right['weight']['carry']
            else:
                self.Tree.nodes[node]['weight']['value'] = in_left['weight']['carry']

            if node != 0:
                out_up['weight']['carry'] = self.Tree.nodes[node]['weight']['value']
                out_up['weight']['epoch'] = self.epoch + 1

            in_left['weight']['carry'] = self.null_label
            in_right['weight']['carry'] = self.null_label
            return True
        return False

    def read_output(self):
        """Return the value stored in the root (node 0): ``'10'``, ``'01'`` or ``'00'`` once computed."""
        return self.Tree.nodes[0]['weight']['value']


In [36]:
# Exhaustive test of the BINARY COMPARATOR network.
# Every ordered pair (i, j) of nbits-bit values is checked.  The former
# `i+j < 2**nbits` filter was a leftover of the sum tests (where it avoids
# overflow) and made this test skip about half of the pairs.
nbits = 8
for i in range(0, 2**nbits):
    for j in range(0, 2**nbits):
        network = bin_comparator(nbits, verbose=False)
        network.set_input(i, j)
        network.init_computation()
        # Iterate until the network reports that nothing changed any more
        while network.compute_iteration():
            pass
        strres = network.read_output()
        right_answer = '10' if i > j else ('00' if i == j else '01')
        assert right_answer == strres, f'comparing {i} with {j}: got {strres}, expected {right_answer}'
print("All OK!")

All OK!


In [37]:
%%manim -qm ShowBinaryComparatorTree

config.frame_width = 19
class ShowBinaryComparatorTree(Scene):
    """Animate the binary comparator tree while it compares two fixed numbers."""

    def construct(self):
        """Build an 8-bit `bin_comparator` network, animate it and print the result."""
        nbits = 8
        network = bin_comparator(nbits, null_label='-')
        a = 184
        b = 226
        network.set_input(a, b)
        # generic_run looks up the keyword 'null_label'; the previous spelling
        # 'null_lable' was silently ignored in favour of the default.
        generic_run(self, network, null_label='-', delay=1)
        # Derive the expected result from a and b so the message stays correct
        # when the inputs above are edited ('10': a>b, '01': a<b, '00': a==b).
        expected = '10' if a > b else ('01' if a < b else '00')
        print(f'result: {network.read_output()}, should be {expected}')


Iteration 0


                                                                                                        

Iteration 1


                                                                                                                 

Iteration 2


                                                                                                                           

Iteration 3


                                                                                                                           

Iteration 4


                                                                                                                  

Iteration 5


                                                                                                                      

result: 01, should be 01


### Playground to test stuff

In [38]:
%%manim -qm LabeledModifiedGraph
class LabeledModifiedGraph(Scene):
    """Playground scene: a labeled circular graph plus a few text labels animated along arcs."""

    def construct(self):
        """Draw the graph, then play three simultaneous text transforms."""
        vertices = [1, 2, 3, 4, 5, 6, 7, 8]
        edges = [(1, 7), (1, 8), (2, 3), (2, 4), (2, 5),
                 (2, 8), (3, 4), (6, 1), (6, 2),
                 (6, 3), (7, 2), (7, 4)]
        # Vertex 7 and three of the edges touching it are highlighted in red
        g = Graph(vertices, edges, layout="circular", layout_scale=3,
                  labels=True, vertex_config={7: {"fill_color": RED}},
                  edge_config={(1, 7): {"stroke_color": RED},
                               (2, 7): {"stroke_color": RED},
                               (4, 7): {"stroke_color": RED}})
        self.add(g)
        # Each Transform moves a fresh "ciao" label from one vertex towards
        # another (the first one starts and ends on vertex 4).
        self.play(Transform(Text("ciao").move_to(g[4]), Text("ciao").move_to(g[4])),
                  Transform(Text("ciao").move_to(g[5]), Text("ciao").next_to(g[1])),
                  Transform(Text("ciao").move_to(g[3]), Text("ciao").next_to(g[8])), path_arc = 90*DEGREES)

                                                                                            