In [10]:
import networkx as nx
from manim import *
import math
import random

# Notebook-wide manim settings: rendered media fills the output cell width and
# only WARNING-level (or more severe) log messages are shown.
config.media_width, config.verbosity = "100%", "WARNING"

# Prefix Sum

In [11]:
def _graph_snapshot(network_graph, layout, null_label):
    """Build a manim ``Graph`` mirroring the current state of ``network_graph``.

    Vertices are labelled with their ``weight['value']`` and filled white when
    that value is set (black when it equals ``null_label``).  Edges whose
    ``weight['carry']`` is set are drawn thick and white; idle edges thin and green.
    """
    return Graph.from_networkx(
        network_graph,
        layout=layout,
        labels={i: network_graph.nodes[i]['weight']['value'] for i in network_graph},
        edge_config={
            (u, v): {
                'stroke_color': (WHITE if d['weight']['carry'] != null_label else GREEN),
                'stroke_width': (6 if d['weight']['carry'] != null_label else 1),
            }
            for u, v, d in network_graph.edges(data=True)
        },
        vertex_config={
            u: {
                'fill_color': (WHITE if network_graph.nodes[u]['weight']['value'] != null_label else BLACK),
                'stroke_width': 2,
                'stroke_color': WHITE,
            }
            for u in network_graph.nodes
        },
    )


def _draw_additional_items(selfi, network, graph):
    """Draw the optional decorations returned by ``network.additional_items_to_draw()``.

    Only ``("RectangleAroundNodes", [node, ...])`` items are understood; they
    become a ``SurroundingRectangle`` around the listed vertices of ``graph``.
    Networks without the hook are skipped with a message, but errors raised
    while drawing are no longer silently swallowed.
    """
    items_hook = getattr(network, 'additional_items_to_draw', None)
    if items_hook is None:
        print("Nothing to add: network has no additional_items_to_draw()")
        return
    for item in items_hook():
        if item[0] == "RectangleAroundNodes":
            group = Group(*[graph[x] for x in item[1]])
            selfi.add(SurroundingRectangle(group))


def generic_run(selfi, network, **kwargs):
    """Run ``network`` step by step, animating each iteration in the scene ``selfi``.

    ``network`` must provide ``get()`` -> ``(graph, layout)``, ``init_computation()``
    and ``compute_iteration()`` (returning False once nothing changed); it may
    also provide ``additional_items_to_draw()``.  Iteration 0 draws the initial
    graph and a step counter; every later iteration morphs the graph to its new
    state while red labels fly along each edge that currently carries a value.

    Keyword args:
        null_label: marker for "empty" values/carries.  Defaults to
            ``network.null_label`` when the network defines one, else ``'-'``.
        delay: base pause / animation duration in seconds (default 1).
    """
    null_label = kwargs.get('null_label', getattr(network, 'null_label', '-'))
    delay = kwargs.get('delay', 1)
    (Network, layout) = network.get()

    network.init_computation()
    finished = False
    time = 0
    while not finished:
        print(f'Iteration {time}')
        if time == 0:
            G = _graph_snapshot(Network, layout, null_label)
            selfi.add(G)
            counter = Text(str(time)).to_edge(UL)
            selfi.add(counter)
            _draw_additional_items(selfi, network, G)
        else:
            selfi.wait(delay)
            G1 = _graph_snapshot(Network, layout, null_label)

            # Edges holding a value right now, with the text to animate along them.
            carrying = [(u, v, str(d['weight']['carry']))
                        for u, v, d in Network.edges(data=True)
                        if d['weight']['carry'] != null_label]
            # Labels start on the source vertex of the (not yet morphed) graph ...
            moving_texts = [Text(label, color=RED, font_size=32).move_to(G[u])
                            for u, v, label in carrying]
            # ... and travel along an arc to the destination vertex.
            selfi.play(Transform(G, G1),
                       *[Transform(text, Text(label, color=RED, font_size=32).move_to(G[v]))
                         for text, (u, v, label) in zip(moving_texts, carrying)],
                       path_arc=70 * DEGREES, run_time=1.5 * delay)
            selfi.play(Transform(counter, Text(str(time)).to_edge(UL)),
                       *[FadeOut(text) for text in moving_texts])
        finished = not network.compute_iteration()
        time = time + 1
    selfi.wait(delay)


In [24]:
class prefix_tree:
    """Message-passing prefix-sum (scan) network laid out as a complete binary tree.

    Node numbering (heap order, ``N = 2*nbits - 1``):
      * ``0 .. nbits-2``   internal nodes; node ``x`` has children ``2x+1`` and ``2x+2``,
      * ``nbits-1 .. N-1`` leaves, which also hold the outputs,
      * ``N .. N+nbits-1`` input nodes; input ``N+i`` feeds leaf ``nbits-1+i``.

    Every node has ``weight = {'value': ...}``; every edge has
    ``weight = {'carry': ..., 'epoch': ..., 'dir': 'up'|'down', ...}``.  Values and
    carries are stored as strings, with ``null_label`` meaning "empty".  A carry
    written during epoch ``e`` is stamped with epoch ``e+1`` and only consumed
    once ``self.epoch`` reaches that stamp, so a message moves at most one hop
    per call to :meth:`compute_iteration`.

    Keyword args:
        null_label: marker for an empty value/carry (default ``'-'``).
        verbose: stored as ``self.verbose``; not read by any method here.

    Raises:
        ValueError: if ``nbits`` is not a power of two >= 2.
    """

    def __init__(self, nbits, **kwargs):
        self.null_label = kwargs.get('null_label', '-')
        self.verbose = kwargs.get('verbose', False)

        # The heap numbering only gives left-to-right prefix order (and the
        # layout below only works) when all leaves are at the same depth.
        if nbits < 2 or nbits & (nbits - 1):
            raise ValueError(f"nbits must be a power of two >= 2, got {nbits}")

        # nbits: number of inputs; N: number of tree nodes (internal + leaves).
        self.nbits = nbits
        self.N = 2 * nbits - 1

        self.Tree = nx.DiGraph()
        for x in range(0, self.N + self.nbits):
            self.Tree.add_node(x, weight={'value': self.null_label})

        for x in range(0, self.nbits - 1):
            # Edges parent -> sons
            self.Tree.add_edge(x, x*2+1, weight={"carry": self.null_label, "dir": "down", "son": "left", "epoch": 0})
            self.Tree.add_edge(x, x*2+2, weight={"carry": self.null_label, "dir": "down", "son": "right", "epoch": 0})
            # Edges son -> parent
            self.Tree.add_edge(x*2+1, x, weight={"carry": self.null_label, "all1s": self.null_label, "dir": "up", 'son': 'left', "epoch": 0})
            self.Tree.add_edge(x*2+2, x, weight={"carry": self.null_label, "all1s": self.null_label, "dir": "up", 'son': 'right', "epoch": 0})

        # Edges input node -> leaf
        for x in range(self.nbits - 1, self.N):
            self.Tree.add_edge(x + self.nbits, x, weight={"carry": self.null_label, "dir": "up", "in": 'a', "epoch": 0})

        # y: one row per tree level (scaled by 1.5), inputs one row below the leaves.
        y_offset = math.floor(math.log2(self.N - 1))
        self.pos_y = [1.5 * math.floor(math.log2(y + 1)) - y_offset for y in range(0, self.N)]
        self.pos_y = self.pos_y + [1.5 * math.floor(math.log2(self.N + 1)) - y_offset
                                   for y in range(self.N, self.N + self.nbits)]

        # x: inputs evenly spaced, leaves right above their input, and each
        # internal node centred between its two children (built bottom-up).
        self.pos_x = {x: x - self.N - self.nbits / 2 for x in range(self.N, self.N + self.nbits)}
        self.pos_x = {x: self.pos_x[x + self.nbits] for x in range(self.nbits - 1, self.N)} | self.pos_x
        for level in range(int(math.log2(self.nbits)) - 1, -1, -1):
            self.pos_x = {x: (self.pos_x[x*2+1] + self.pos_x[x*2+2]) / 2
                          for x in range(2**level - 1, 2**level - 1 + 2**level)} | self.pos_x

        self.layout = {v: [self.pos_x[v], self.pos_y[v], 0] for v in range(0, self.N + self.nbits)}

        self.epoch = 0

    def additional_items_to_draw(self):
        """Return drawing hints: one rectangle around the inputs, one around the leaves."""
        return [
            ('RectangleAroundNodes', list(range(self.N, self.N + self.nbits))),
            ('RectangleAroundNodes', list(range(self.nbits - 1, self.N))),
        ]

    def is_input(self, v):
        """True for the input nodes ``N .. N+nbits-1``."""
        return v in range(self.N, self.N + self.nbits)

    def is_output(self, v):
        """True for the leaves ``nbits-1 .. N-1`` (they hold the results)."""
        return v in range(self.nbits - 1, self.N)

    def is_internal_node(self, v):
        """True for the internal tree nodes ``0 .. nbits-2`` (node 0 is the root)."""
        return v in range(0, self.nbits - 1)

    def get(self):
        """Return ``(graph, layout)`` for drawing."""
        return (self.Tree, self.layout)

    def set_input(self, in_seq):
        """Store the input sequence; it is loaded by :meth:`init_computation`."""
        self.in_seq = in_seq

    def init_computation(self):
        """Reset all state and load ``self.in_seq`` into the input nodes.

        Clears every node value and edge carry and rewinds ``epoch`` to 0, so a
        network can be run again after a new :meth:`set_input`.

        Raises:
            ValueError: if the input sequence length is not ``nbits``.
        """
        if len(self.in_seq) != self.nbits:
            raise ValueError(f"expected {self.nbits} inputs, got {len(self.in_seq)}")
        self.epoch = 0
        for node in self.Tree.nodes:
            self.Tree.nodes[node]['weight']['value'] = self.null_label
        for _, _, d in self.Tree.edges(data=True):
            d['weight']['carry'] = self.null_label
            d['weight']['epoch'] = 0
        for i, value in enumerate(self.in_seq):
            self.Tree.nodes[self.N + i]['weight']['value'] = str(value)

    def _ready(self, edge):
        """True when ``edge`` holds a carry that may be consumed in the current epoch."""
        return edge['weight']['carry'] != self.null_label and edge['weight']['epoch'] <= self.epoch

    def _send(self, edge, carry):
        """Put ``carry`` on ``edge``, consumable from the next epoch on."""
        edge['weight']['carry'] = carry
        edge['weight']['epoch'] = self.epoch + 1

    def _step_input(self, node):
        """Epoch 0 only: copy the input node's value onto its outgoing edge."""
        if self.epoch != 0:
            return False
        value = self.Tree.nodes[node]['weight']['value']
        updated = False
        for u, _, d in self.Tree.out_edges(node, data=True):
            assert u == node
            if value != self.null_label:
                self._send(d, value)
                updated = True
        return updated

    def _step_leaf(self, node):
        """Leaf: adopt the input value and forward it up; add partial sums arriving from above."""
        in_a = in_top = out_up = None
        for _, _, e in self.Tree.in_edges(node, data=True):
            if e['weight']['dir'] == 'up':
                in_a = e
            if e['weight']['dir'] == 'down':
                in_top = e
        for _, _, e in self.Tree.out_edges(node, data=True):
            if e['weight']['dir'] == 'up':
                out_up = e
        assert in_a is not None
        assert in_top is not None
        assert out_up is not None

        updated = False
        if self._ready(in_a):
            self.Tree.nodes[node]['weight']['value'] = in_a['weight']['carry']
            self._send(out_up, in_a['weight']['carry'])
            in_a['weight']['carry'] = self.null_label
            updated = True

        if self._ready(in_top):
            value = self.Tree.nodes[node]['weight']['value']
            self.Tree.nodes[node]['weight']['value'] = str(int(in_top['weight']['carry']) + int(value))
            in_top['weight']['carry'] = self.null_label
            updated = True
        return updated

    def _step_internal(self, node):
        """Internal node: combine children's sums (sending the left sum to the right
        child and the total up), and forward any value from the parent to both children."""
        in_left = in_right = in_top = None
        out_up = out_left = out_right = None
        for _, _, e in self.Tree.in_edges(node, data=True):
            if e['weight']['dir'] == 'up':
                if e['weight']['son'] == 'left':
                    in_left = e
                if e['weight']['son'] == 'right':
                    in_right = e
            if e['weight']['dir'] == 'down':
                in_top = e
        for _, _, e in self.Tree.out_edges(node, data=True):
            if e['weight']['dir'] == 'up':
                out_up = e
            if e['weight']['dir'] == 'down':
                if e['weight']['son'] == 'left':
                    out_left = e
                if e['weight']['son'] == 'right':
                    out_right = e
        assert in_left is not None
        assert in_right is not None
        assert in_top is not None or node == 0   # only the root has no parent
        assert out_up is not None or node == 0
        assert out_left is not None
        assert out_right is not None

        updated = False
        if self._ready(in_left) and self._ready(in_right):
            left_sum = in_left['weight']['carry']
            total = str(int(left_sum) + int(in_right['weight']['carry']))
            # The right subtree must add everything in the left subtree.
            self._send(out_right, left_sum)
            self.Tree.nodes[node]['weight']['value'] = total
            if node != 0:
                self._send(out_up, total)
            in_left['weight']['carry'] = self.null_label
            in_right['weight']['carry'] = self.null_label
            updated = True

        if node != 0 and self._ready(in_top):
            # A prefix from further left applies to both subtrees.
            self._send(out_left, in_top['weight']['carry'])
            self._send(out_right, in_top['weight']['carry'])
            in_top['weight']['carry'] = self.null_label
            updated = True
        return updated

    def compute_iteration(self):
        """Advance the network by one epoch.

        Nodes are visited in graph order (ascending ids, parents before
        children).  Returns True if any node consumed or produced a message;
        False means the computation has finished.
        """
        some_updates = False
        for node in self.Tree.nodes:
            if self.is_input(node):
                updated = self._step_input(node)
            elif self.is_output(node):
                updated = self._step_leaf(node)
            elif self.is_internal_node(node):
                updated = self._step_internal(node)
            else:
                updated = False
            some_updates = updated or some_updates
        self.epoch = self.epoch + 1
        return some_updates

    def read_output(self):
        """Return the leaf values (the prefix sums) as ints, left to right.

        Raises ValueError if a leaf still holds ``null_label``.
        """
        return [int(self.Tree.nodes[node]['weight']['value']) for node in range(self.nbits - 1, self.N)]


In [25]:
from itertools import accumulate

# Cross-check prefix_tree against a sequential running sum: the first trial
# uses 0..nleaves-1, the remaining ones random integers (negatives included).
nleaves = 16
n_trials = 1000
trial_rng = random.Random(0)  # fixed seed so a failing trial can be reproduced

test_seq = list(range(nleaves))
for trial in range(n_trials):
    network = prefix_tree(nleaves, verbose=False)
    network.set_input(test_seq)
    network.init_computation()
    while network.compute_iteration():
        pass  # keep stepping until an epoch moves no message
    res = network.read_output()
    expected = list(accumulate(test_seq))
    assert res == expected, f"trial {trial}: input {test_seq} -> {res}, expected {expected}"
    test_seq = [trial_rng.randint(-nleaves, nleaves * nleaves) for _ in range(nleaves)]
print("All OK!")

All OK!


In [26]:
%%manim -qm ShowLinear

config.frame_width = 20  # widen the frame (presumably so the whole tree and input row fit)

class ShowLinear(Scene):
    """Animate a 16-leaf prefix_tree run on random inputs drawn from [0, 16]."""

    def construct(self):
        nleaves = 16
        network = prefix_tree(nleaves)
        network.set_input([random.randint(0, nleaves) for _ in range(nleaves)])
        # `delay` is read by generic_run; prefix_tree ignored it when passed there.
        generic_run(self, network, delay=1)

Iteration 0
Iteration 1


                                                                                                                           

Iteration 2


                                                                                                                           

Iteration 3


                                                                                                                           

Iteration 4


                                                                                                                            

Iteration 5


                                                                                                                            

Iteration 6


                                                                                                                            

Iteration 7


                                                                                                                            

Iteration 8


                                                                                                                            

Iteration 9


                                                                                                                            

Iteration 10


                                                                                                                      