In [53]:
import networkx as nx
from manim import *
import math

config.media_width = "100%"
config.verbosity = "WARNING"

In [543]:
# Bitwise operations needed
def flip_bit(a, n):
	# a: number to modify
	# n: bit to flip (starting at 0)
	x = 1<<n
	return a^x

def count_ones(x):
	res = 0
	#f = "{0:0"+str(8)+"b}"

	while x!=0:
		#print(f.format(x))
		if (x%2 == 1):
			res = res+1
		x = x>>1
	#print(f'{f.format(x)} result = {res}')
	return res


In [55]:
# Test bitwise operations
assert(flip_bit(7, 0) == 6)
assert(flip_bit(7, 1) == 5)
assert(flip_bit(7, 2) == 3)
assert(flip_bit(7, 3) == 15)

In [272]:
def generic_run(selfi, network, **kwargs):
    null_label = kwargs['null_label'] if 'null_label' in kwargs else '-'
    delay = kwargs['delay'] if 'delay' in kwargs else 1
    (Network, layout) = network.get()

    unet = Network.to_undirected()
    network.init_computation()
    finished = False
    time = 0
    while not finished:
        print(f'Iteration {time}')
        if time == 0:
            G = Graph.from_networkx(Network, layout=layout, labels={i : Network.nodes[i]['weight']['value'] for i in Network}, 
                                    edge_config={(u,v): {'stroke_color': (WHITE if d['weight']['carry'] != null_label else GREEN),
                                                         'stroke_width': (6 if d['weight']['carry'] != null_label else 1)} 
                                                         for u, v, d in Network.edges(data = True)},
                                    vertex_config={u: {'fill_color': (WHITE if Network.nodes[u]['weight']['value'] != null_label else BLACK),
                                                       'stroke_width': 2, 'stroke_color': WHITE}
                                                   for u in Network.nodes})
            selfi.play(Create(G))
            counter = Text(str(time)).to_edge(UL)
            selfi.add(counter)
            try:
                list = network.additional_items_to_draw()
                for item in list:
                    if (item[0] == "RectangleAroundNodes"):
                        group = Group(*[G[x] for x in item[1]])
                        selfi.add(SurroundingRectangle(group))
            except Exception as e:
                print("Nothing to add:", e)
        else:
            selfi.wait(delay)
            G1 = Graph.from_networkx(Network, layout=layout, labels={i : Network.nodes[i]['weight']['value'] for i in Network}, 
                                    edge_config={(u,v): {'stroke_color': (WHITE if d['weight']['carry'] != null_label else GREEN),
                                                         'stroke_width': (6 if d['weight']['carry'] != null_label else 1)} 
                                                         for u, v, d in Network.edges(data = True)},
                                    vertex_config={u: {'fill_color': (WHITE if Network.nodes[u]['weight']['value'] != null_label else BLACK),
                                                       'stroke_width': 2, 'stroke_color': WHITE}
                                                   for u in Network.nodes})

            new_texts_from = {(u, v): Text(str(Network[u][v]['weight']['carry']), color=RED, font_size=32)
                        for u, v, d in Network.edges(data=True) if d['weight']['carry'] != null_label}
            new_texts_to = {(u, v): Text(str(Network[u][v]['weight']['carry']), color=RED, font_size=32)
                        for u, v, d in Network.edges(data=True) if d['weight']['carry'] != null_label}
            selfi.play(Transform(G, G1), *[Transform(new_texts_from[(u, v)].move_to(G[u]), new_texts_to[(u, v)].move_to(G[v])) 
                                           for u, v, d in Network.edges(data=True) if d['weight']['carry'] != null_label], 
                        path_arc = 70 * DEGREES, run_time=1.5*delay)
            selfi.play(Transform(counter, Text(str(time)).to_edge(UL)) ,*[FadeOut(new_texts_from[(u, v)]) for u, v, d in Network.edges(data=True) if d['weight']['carry'] != null_label], run_time=1.5*delay)
        finished = not network.compute_iteration()
        time = time+1
    selfi.wait(delay)


## Odd-Even Sort

In [None]:
class odd_even:
    def node_id(self, line, stage, comparator):
        assert(comparator>=0 and comparator < 4)
        assert(line < self.lines)
        assert(stage < self.lines)
        return (line<<int(math.log2(self.lines)*2+2)) + (stage<<int(math.log2(self.lines)+2)) + comparator
    
    def items_from_node_id(self, nodeid):
        return ((nodeid &  ((1<<int(math.log2(self.lines)))-1)<<int(math.log2(self.lines)*2+2)) >> int(math.log2(self.lines)*2+2),
                (nodeid &  ((1<<int(math.log2(self.lines)))-1)<<int(math.log2(self.lines)+2)) >> int(math.log2(self.lines)+2),
                nodeid&3)

    def input_node(self, line):
        return -(line+1)
    
    def inputs(self):
        return range(-1, -(self.lines+1), -1)
    
    def output_node(self, line):
        return -(line+1)-2*self.lines

    def outputs(self):
        return range(-1-2*self.lines, -(self.lines+1)-2*self.lines, -1)
    
    def __init__(self, nlines, **kwargs):
        #print(kwargs)
        self.null_label = kwargs['null_label'] if 'null_label' in kwargs else '-'
        self.verbose = kwargs['verbose'] if 'verbose' in kwargs else False

        # nbits is the number of bits of the operation
        # N is the number of nodes in the tree (2^n-1). Nodes in the tree will be from 0 to 2^n-2
        self.lines = nlines
        self.N = 2*nlines-1

        self.Butterly = nx.DiGraph()
        self.layout = {}
        for stage in range(0, self.lines):
            for line in range((stage%2), self.lines-(stage%2), 2):
                #print(f'{stage}, {line}: {self.node_id(line, stage, 0)} {self.node_id(line, stage, 1)} {self.node_id(line+1, stage, 0)} {self.node_id(line+1, stage, 1)}')
                if stage%2 == 0:
                    self.OddEven.add_node(self.node_id(line, stage, 0), weight = {'value': self.null_label})
                    self.OddEven.add_node(self.node_id(line, stage, 1), weight = {'value': self.null_label})
                    self.OddEven.add_node(self.node_id(line+1, stage, 0), weight = {'value': self.null_label})
                    self.OddEven.add_node(self.node_id(line+1, stage, 1), weight = {'value': self.null_label})
                    self.OddEven.add_edge(self.node_id(line, stage, 0), self.node_id(line, stage, 1), weight={"carry": self.null_label, "epoch": 0})
                    self.OddEven.add_edge(self.node_id(line+1, stage, 0), self.node_id(line+1, stage, 1), weight={"carry": self.null_label, "epoch": 0})
                    self.OddEven.add_edge(self.node_id(line, stage, 0), self.node_id(line+1, stage, 1), weight={"carry": self.null_label, "epoch": 0})
                    self.OddEven.add_edge(self.node_id(line+1, stage, 0), self.node_id(line, stage, 1), weight={"carry": self.null_label, "epoch": 0})
                    self.layout[self.node_id(line, stage, 0)] = [2*stage - self.lines+1, -(line - self.lines/2+1), 0]
                    self.layout[self.node_id(line, stage, 1)] = [2*stage+1 - self.lines+1, -(line - self.lines/2+1), 0]
                    self.layout[self.node_id(line+1, stage, 0)] = [2*stage - self.lines+1, -(line+1 - self.lines/2+1), 0]
                    self.layout[self.node_id(line+1, stage, 1)] = [2*stage+1 - self.lines+1, -(line+1 - self.lines/2+1), 0]
                else:
                    self.OddEven.add_node(self.node_id(line, stage, 0), weight = {'value': self.null_label})
                    self.OddEven.add_node(self.node_id(line, stage, 1), weight = {'value': self.null_label})
                    self.OddEven.add_node(self.node_id(line+1, stage, 0), weight = {'value': self.null_label})
                    self.OddEven.add_node(self.node_id(line+1, stage, 1), weight = {'value': self.null_label})
                    self.OddEven.add_edge(self.node_id(line, stage, 0), self.node_id(line, stage, 1), weight={"carry": self.null_label, "epoch": 0})
                    self.OddEven.add_edge(self.node_id(line+1, stage, 0), self.node_id(line+1, stage, 1), weight={"carry": self.null_label, "epoch": 0})
                    self.OddEven.add_edge(self.node_id(line, stage, 0), self.node_id(line+1, stage, 1), weight={"carry": self.null_label, "epoch": 0})
                    self.OddEven.add_edge(self.node_id(line+1, stage, 0), self.node_id(line, stage, 1), weight={"carry": self.null_label, "epoch": 0})
                    self.layout[self.node_id(line, stage, 0)] = [2*stage - self.lines+1, -(line - self.lines/2+1), 0]
                    self.layout[self.node_id(line, stage, 1)] = [2*stage+1 - self.lines+1, -(line - self.lines/2+1), 0]
                    self.layout[self.node_id(line+1, stage, 0)] = [2*stage - self.lines+1, -(line+1 - self.lines/2+1), 0]
                    self.layout[self.node_id(line+1, stage, 1)] = [2*stage+1 - self.lines+1, -(line+1 - self.lines/2+1), 0]

        # Input and output nodes
        for line in range(0, self.lines):
            self.OddEven.add_node(self.input_node(line), weight = {'value': self.null_label})
            self.OddEven.add_edge(self.input_node(line), self.node_id(line, 0, 0), weight={"carry": self.null_label, "epoch": 0})
            self.OddEven.add_node(self.output_node(line), weight = {'value': self.null_label})
            if line == 0 or line == self.lines-1:
                self.OddEven.add_edge(self.node_id(line, self.lines-2, 1), self.output_node(line), weight={"carry": self.null_label, "epoch": 0})
            else:
                self.OddEven.add_edge(self.node_id(line, self.lines-1, 1), self.output_node(line), weight={"carry": self.null_label, "epoch": 0})
            self.layout[self.input_node(line)] = [0 - self.lines - 1+1, -(line - self.lines/2+1), 0]
            self.layout[self.output_node(line)] = [2*self.lines - self.lines - 2+3, -(line - self.lines/2+1), 0]

        # Inter-comparator connections
        for gap in range(0, self.lines-1):
            for line in range(1, self.lines-1):
                assert(self.node_id(line, gap, 1) in self.OddEven.nodes)
                assert(self.node_id(line, gap+1, 0) in self.OddEven.nodes)
                self.OddEven.add_edge(self.node_id(line, gap, 1), self.node_id(line, gap+1, 0), weight={"carry": self.null_label, "epoch": 0})
        for gap in range(0, self.lines-2, 2):
            assert(self.node_id(0, gap, 1) in self.OddEven.nodes)
            assert(self.node_id(0, gap+2, 0) in self.OddEven.nodes)
            self.OddEven.add_edge(self.node_id(0, gap, 1), self.node_id(0, gap+2, 0), weight={"carry": self.null_label, "epoch": 0})
            self.OddEven.add_edge(self.node_id(self.lines-1, gap, 1), self.node_id(self.lines-1, gap+2, 0), weight={"carry": self.null_label, "epoch": 0})

    def additional_items_to_draw(self):
        list = []
        list.append(('RectangleAroundNodes', [x for x in self.inputs()]))
        list.append(('RectangleAroundNodes', [x for x in self.outputs()]))
        return list

    def set_input(self, seq):
        self.seq = seq

    def get(self):
        return self.OddEven, self.layout

    def init_computation(self):
        for node, value in zip(self.inputs(), self.seq):
            self.OddEven.nodes[node]['weight']['value'] = str(value)
        self.epoch = 0

    def compute_iteration(self):
        some_updates = False
        for node in self.OddEven:
            if node in self.inputs():
                if self.epoch == 0:
                    for u, v, d in self.OddEven.out_edges(node, data=True):
                        assert(u==node)
                        if self.OddEven.nodes[node]['weight']['value'] != self.null_label:
                            some_updates = True
                            d['weight']['carry'] = self.OddEven.nodes[node]['weight']['value']
                            d['weight']['epoch'] = self.epoch+1
                else:
                    None
            else:
                if node in self.outputs():
                    in_a = None
                    for u, v, e in self.OddEven.in_edges(node, data=True):
                        in_a = e
                    assert(in_a != None)

                    if in_a['weight']['carry'] != self.null_label and in_a['weight']['epoch'] <= self.epoch:
                        some_updates = True
                        self.OddEven.nodes[node]['weight']['value'] = in_a['weight']['carry']
                        in_a['weight']['carry'] = self.null_label

                else:
                    line, stage, comp = self.items_from_node_id(node)
                    #print(f'node {node} ({line}, {stage}, {comp})')
                    if comp == 0: # First stage of the comparator
                        in_a = None
                        for u, v, e in self.OddEven.in_edges(node, data=True):
                            in_a = e

                        out = []
                        for u, v, e in self.OddEven.out_edges(node, data=True):
                            out.append(e)

                        assert(len(out) == 2)
                        assert(in_a != None)
                        if in_a['weight']['carry'] != self.null_label and in_a['weight']['epoch'] <= self.epoch:
                            some_updates = True
                            self.OddEven.nodes[node]['weight']['value'] = in_a['weight']['carry']
                            out[0]['weight']['carry'] = self.OddEven.nodes[node]['weight']['value']
                            out[0]['weight']['epoch'] = self.epoch+1
                            out[1]['weight']['carry'] = self.OddEven.nodes[node]['weight']['value']
                            out[1]['weight']['epoch'] = self.epoch+1
                            in_a['weight']['carry'] = self.null_label
                    else:
                        assert(comp==1)
                        out_a = None
                        for u, v, e in self.OddEven.out_edges(node, data=True):
                            out_a = e
                        ins = []
                        for u, v, e in self.OddEven.in_edges(node, data=True):
                            ins.append(e)
                        assert(len(ins) == 2)
                        assert(out_a != None)
                        if ins[0]['weight']['carry'] != self.null_label and ins[1]['weight']['carry'] != self.null_label and ins[0]['weight']['epoch'] <= self.epoch and ins[1]['weight']['epoch'] <= self.epoch:
                            #if stage%2==0:
                                if line%2==stage%2:
                                    # Up
                                    some_updates = True
                                    out_a['weight']['carry'] = str(min(int(ins[0]['weight']['carry']), int(ins[1]['weight']['carry'])))
                                    self.OddEven.nodes[node]['weight']['value'] = out_a['weight']['carry']
                                    out_a['weight']['epoch'] = self.epoch+1
                                    ins[0]['weight']['carry'] = self.null_label
                                    ins[1]['weight']['carry'] = self.null_label
                                else:
                                    # Down
                                    some_updates = True
                                    out_a['weight']['carry'] = str(max(int(ins[0]['weight']['carry']), int(ins[1]['weight']['carry'])))
                                    self.OddEven.nodes[node]['weight']['value'] = out_a['weight']['carry']
                                    out_a['weight']['epoch'] = self.epoch+1
                                    ins[0]['weight']['carry'] = self.null_label
                                    ins[1]['weight']['carry'] = self.null_label

        self.epoch = self.epoch + 1
        return some_updates

    def read_output(self):
        return [self.OddEven.nodes[x]['weight']['value'] for x in self.outputs()]

In [268]:
import random

for l in range(1, 6):
    lines = 1<<l
    net = odd_even(lines)
    for line in range(0, lines, 2):
        for stage in range(0, lines, 2):
            for comp in range(0, 4):
                id = net.node_id(line, stage, comp)
                (l, s, c) = net.items_from_node_id(id)
                assert(line == l)
                assert(stage == s)
                assert(comp == c)
                id = net.node_id(line+1, stage, comp)
                (l, s, c) = net.items_from_node_id(id)
                assert(line+1 == l)
                assert(stage == s)
                assert(comp == c)
                id = net.node_id(line, stage+1, comp)
                (l, s, c) = net.items_from_node_id(id)
                assert(line == l)
                assert(stage+1 == s)
                assert(comp == c)
                id = net.node_id(line+1, stage+1, comp)
                (l, s, c) = net.items_from_node_id(id)
                assert(line+1 == l)
                assert(stage+1 == s)
                assert(comp == c)

nleaves = 8
test_seq = [x for x in range(nleaves, 0, -1)]
for i in range(0, 1000):
    network = odd_even(nleaves, verbose = False)
    network.set_input(test_seq)
    network.init_computation()
    while network.compute_iteration():
        None

    res = network.read_output()
    #print(res)
    v = res[0]
    for x in range(1, nleaves):
        assert(int(v) <= int(res[x]))
        v = int(res[x])
    test_seq = [random.randint(-nleaves, nleaves*nleaves) for x in range(0, nleaves)]

print("All is OK!")


All is OK!


In [271]:
%%manim -qm ShowOddEven

config.frame_width = 20


class ShowOddEven(Scene):
    
    def construct(self):
        nbits = 8
        network = odd_even(nbits, delay=0.1)
        network.set_input([8, 7, 6, 5, 4, 3, 2 ,1])
        network.init_computation()
        generic_run(self, network)
        #print(network.read_output())

Iteration 0


                                                                                                                    

Iteration 1


                                                                                                                             

Iteration 2


                                                                                                                           

Iteration 3


                                                                                                                             

Iteration 4


                                                                                                                              

Iteration 5


                                                                                                                              

Iteration 6


                                                                                                                            

Iteration 7


                                                                                                                              

Iteration 8


                                                                                                                              

Iteration 9


                                                                                                                              

Iteration 10


                                                                                                                              

Iteration 11


                                                                                                                              

Iteration 12


                                                                                                                              

Iteration 13


                                                                                                                              

Iteration 14


                                                                                                                              

Iteration 15


                                                                                                                              

Iteration 16


                                                                                                                              

Iteration 17


                                                                                                                              

Iteration 18


                                                                                                                        

# Butterly Network (for DFT)

In [774]:
class butterfly:
    def node_id(self, line, stage, comparator):
        assert(comparator>=0 and comparator < 4)
        assert(line < self.lines)
        assert(stage < self.lines)
        return (line<<int(math.log2(self.lines)*2+2)) + (stage<<int(math.log2(self.lines)+2)) + comparator
    
    def items_from_node_id(self, nodeid):
        return ((nodeid &  ((1<<int(math.log2(self.lines)))-1)<<int(math.log2(self.lines)*2+2)) >> int(math.log2(self.lines)*2+2),
                (nodeid &  ((1<<int(math.log2(self.lines)))-1)<<int(math.log2(self.lines)+2)) >> int(math.log2(self.lines)+2),
                nodeid&3)

    def input_node(self, line):
        return -(line+1)
    
    def inputs(self):
        return range(-1, -(self.lines+1), -1)
    
    def output_node(self, line):
        return -(line+1)-2*self.lines

    def outputs(self):
        return range(-1-2*self.lines, -(self.lines+1)-2*self.lines, -1)

    def __init__(self, nlines, **kwargs):
        #print(kwargs)
        self.null_label = kwargs['null_label'] if 'null_label' in kwargs else '-'
        self.verbose = kwargs['verbose'] if 'verbose' in kwargs else False

        # nbits is the number of bits of the operation
        # N is the number of nodes in the tree (2^n-1). Nodes in the tree will be from 0 to 2^n-2
        self.lines = nlines
        self.N = 2*nlines-1

        self.Butterfly = nx.DiGraph()

        node = 0
        self.layout = {}
        stage_x_starts = -self.lines/2
        for stage in range(0, int(math.log2(self.lines))):
            bit_to_flip = int(math.log2(self.lines)) - stage - 1
            stage_width = 1+bit_to_flip
            for line in range(0, self.lines):
                if line&(1<<bit_to_flip) == 0:
                    # starting line
                    node0 = self.node_id(line, stage, 0)
                    node1 = self.node_id(line, stage, 1)
                    node2 = self.node_id( flip_bit(line, bit_to_flip), stage, 0)
                    node3 = self.node_id( flip_bit(line, bit_to_flip), stage, 1)
                    self.Butterfly.add_node(node0, weight = {'value': self.null_label})
                    self.Butterfly.add_node(node1, weight = {'value': self.null_label})
                    self.Butterfly.add_node(node2, weight = {'value': self.null_label})
                    self.Butterfly.add_node(node3, weight = {'value': self.null_label})
                    self.Butterfly.add_edge(node0, node1, weight={"carry": self.null_label, "epoch": 0})
                    self.Butterfly.add_edge(node0, node3, weight={"carry": self.null_label, "epoch": 0})
                    self.Butterfly.add_edge(node2, node1, weight={"carry": self.null_label, "epoch": 0})
                    self.Butterfly.add_edge(node2, node3, weight={"carry": self.null_label, "epoch": 0})
                    self.layout[node0] = [ stage_x_starts,   - line + self.lines/2, 0 ]
                    self.layout[node1] = [ stage_x_starts+stage_width, - line + self.lines/2, 0 ]
                    self.layout[node2] = [ stage_x_starts,   - flip_bit(line, bit_to_flip) + self.lines/2, 0 ]
                    self.layout[node3] = [ stage_x_starts+stage_width, - flip_bit(line, bit_to_flip) + self.lines/2, 0 ]

            stage_x_starts = stage_x_starts + stage_width + 1

        minx = min([x[0] for x in self.layout.values()])
        maxx = max([x[0] for x in self.layout.values()])
        miny = min([x[1] for x in self.layout.values()])
        maxy = max([x[1] for x in self.layout.values()])

        # Input and output nodes
        for line in range(0, self.lines):
            self.Butterfly.add_node(self.input_node(line), weight = {'value': self.null_label})
            self.Butterfly.add_edge(self.input_node(line), self.node_id(line, 0, 0), weight={"carry": self.null_label, "epoch": 0})
            self.Butterfly.add_node(self.output_node(line), weight = {'value': self.null_label})
            self.Butterfly.add_edge(self.node_id(line, int(math.log2(self.lines))-1, 1), self.output_node(line), weight={"carry": self.null_label, "epoch": 0})
            self.layout[self.input_node(line)] = [minx-1, maxy - line, 0]
            self.layout[self.output_node(line)] = [maxx+1, maxy - line, 0]

        for stage in range(0, int(math.log2(self.lines))-1):
            for line in range(0, self.lines):
                source = self.node_id(line, stage, 1)
                dest   = self.node_id(line, stage+1, 0)
                self.Butterfly.add_edge(source, dest, weight={"carry": self.null_label, "epoch": 0})

        self.epoch = 0
        
    def get(self):
        return (self.Butterfly, self.layout)#, self.edges_labels_pos)

    def set_input(self, in_seq):
        self.in_seq = in_seq

    def init_computation(self):
        for n, v in zip(self.inputs(), self.in_seq):
            self.Butterfly.nodes[n]['weight']['value'] = str(v)

    def compute_iteration(self):
        None
        
    def read_output(self):
        None

In [768]:
%%manim -qm ShowButterfly

config.frame_width = 32
import random

class ShowButterfly(Scene):
    
    def construct(self):
        nleaves = 16
        network = butterfly(nleaves, delay=1)
        firsthalf = [random.randint(0, nleaves) for x in range(0, nleaves>>1)]
        firsthalf.sort()
        secondhalf = [random.randint(0, nleaves) for x in range(0, nleaves>>1)]
        secondhalf.sort(reverse=True)
        network.set_input(firsthalf + secondhalf)
        generic_run(self, network)
        #print(network.read_output())

Iteration 0


                                                                                                                    

Nothing to add: 'butterfly' object has no attribute 'additional_items_to_draw'


In [776]:
%%manim -qm BitonicMerge

config.frame_width = 32

class BitonicSequenceMerge(butterfly):
    def __init__(self, nlines, **kwargs):
        super(BitonicSequenceMerge, self).__init__(nlines, **kwargs)

    def compute_iteration(self):
        some_updates = False
        for node in self.Butterfly:
            if node in self.inputs():
                if self.epoch == 0:
                    for u, v, d in self.Butterfly.out_edges(node, data=True):
                        assert(u==node)
                        if self.Butterfly.nodes[node]['weight']['value'] != self.null_label:
                            some_updates = True
                            d['weight']['carry'] = self.Butterfly.nodes[node]['weight']['value']
                            d['weight']['epoch'] = self.epoch+1
                else:
                    None
            else:
                if node in self.outputs():
                    in_a = None
                    for u, v, e in self.Butterfly.in_edges(node, data=True):
                        in_a = e
                    assert(in_a != None)

                    if in_a['weight']['carry'] != self.null_label and in_a['weight']['epoch'] <= self.epoch:
                        some_updates = True
                        self.Butterfly.nodes[node]['weight']['value'] = in_a['weight']['carry']
                        in_a['weight']['carry'] = self.null_label

                else:
                    line, stage, comp = self.items_from_node_id(node)
                    #print(f'node {node} ({line}, {stage}, {comp})')
                    if comp == 0: # First stage of the comparator
                        in_a = None
                        for u, v, e in self.Butterfly.in_edges(node, data=True):
                            in_a = e

                        out = []
                        for u, v, e in self.Butterfly.out_edges(node, data=True):
                            out.append(e)

                        assert(len(out) == 2)
                        assert(in_a != None)
                        if in_a['weight']['carry'] != self.null_label and in_a['weight']['epoch'] <= self.epoch:
                            some_updates = True
                            self.Butterfly.nodes[node]['weight']['value'] = in_a['weight']['carry']
                            out[0]['weight']['carry'] = self.Butterfly.nodes[node]['weight']['value']
                            out[0]['weight']['epoch'] = self.epoch+1
                            out[1]['weight']['carry'] = self.Butterfly.nodes[node]['weight']['value']
                            out[1]['weight']['epoch'] = self.epoch+1
                            in_a['weight']['carry'] = self.null_label
                    else:
                        assert(comp==1)
                        out_a = None
                        for u, v, e in self.Butterfly.out_edges(node, data=True):
                            out_a = e
                        ins = []
                        for u, v, e in self.Butterfly.in_edges(node, data=True):
                            ins.append(e)
                        assert(len(ins) == 2)
                        assert(out_a != None)
                        if ins[0]['weight']['carry'] != self.null_label and ins[1]['weight']['carry'] != self.null_label and ins[0]['weight']['epoch'] <= self.epoch and ins[1]['weight']['epoch'] <= self.epoch:
                            #if stage%2==0:
                                if line%2==stage%2:
                                    # Up
                                    some_updates = True
                                    out_a['weight']['carry'] = str(min(int(ins[0]['weight']['carry']), int(ins[1]['weight']['carry'])))
                                    self.Butterfly.nodes[node]['weight']['value'] = out_a['weight']['carry']
                                    out_a['weight']['epoch'] = self.epoch+1
                                    ins[0]['weight']['carry'] = self.null_label
                                    ins[1]['weight']['carry'] = self.null_label
                                else:
                                    # Down
                                    some_updates = True
                                    out_a['weight']['carry'] = str(max(int(ins[0]['weight']['carry']), int(ins[1]['weight']['carry'])))
                                    self.Butterfly.nodes[node]['weight']['value'] = out_a['weight']['carry']
                                    out_a['weight']['epoch'] = self.epoch+1
                                    ins[0]['weight']['carry'] = self.null_label
                                    ins[1]['weight']['carry'] = self.null_label

        self.epoch = self.epoch + 1
        return some_updates
    
class BitonicMerge(Scene):
    
    def construct(self):
        nleaves = 16
        network = BitonicSequenceMerge(nleaves, delay=1)
        firsthalf = [random.randint(0, nleaves) for x in range(0, nleaves>>1)]
        firsthalf.sort()
        secondhalf = [random.randint(0, nleaves) for x in range(0, nleaves>>1)]
        secondhalf.sort(reverse=True)
        network.set_input(firsthalf + secondhalf)
        generic_run(self, network)
        #print(network.read_output())

Iteration 0


                                                                                                                    

Nothing to add: 'BitonicSequenceMerge' object has no attribute 'additional_items_to_draw'
Iteration 1


                                                                                                                             

Iteration 2


                                                                                                                             

Iteration 3


                                                                                                                             

Iteration 4


                                                                                                                              

Iteration 5


                                                                                                                              

Iteration 6


                                                                                                                              

Iteration 7


                                                                                                                              

Iteration 8


                                                                                                                              

Iteration 9


                                                                                                                              

Iteration 10


                                                                                                                        

In [722]:
%%manim -qm Hypercube

config.frame_width = 10
import random

def draw_node(id, bits, pos):
    result = VGroup()
    f = "{0:0"+str(bits)+"b}"
    text = Text(f.format(id), color=BLACK, font='monospace', font_size=16).move_to(pos)
    #print(f'radius=max(text.width, text.height) = radius=max({text.width}, {text.height}) = {max(text.width, text.height)}')
    circle = Dot(point=pos, radius=max(text.width, text.height), color=WHITE)
    return result.add(circle, text)

class myarch(ArcBetweenPoints):
    def __init__(self, *args, **kwargs):
        super(myarch, self).__init__(angle=PI/3, *args, **kwargs)

class Hypercube(MovingCameraScene):

    def construct(self):
        edges_full = {}
        new_nodes = []

        nodes = {0: draw_node(0, 0, [0,0,0])}
        nodes[0].z_index = 1
        self.add(nodes[0])
        Title = Text("Hyercube").move_to(UP)

        shifts = [[0,0,0], [-2, 0, 0], [0,-2,0], [-1,-1,0], [-2, -4, 0], [-4, 1, 0], [0,0,0]]
        arches = [Line, Line, Line, Line, Line, Line, Line]
        widths = [0, 14, 18, 30, 32, 34, 34]
        for bits in range(1, 6):
            print(f'                          ITERATION {bits} (width={widths[bits]})')
            new_nodes = {i+len(nodes) : draw_node(i+len(nodes), bits, nodes[i].get_center()) for i in range(0, len(nodes))}
            print(f'new_nodes = {new_nodes}')
            new_sub_edges = {}
            new_cross_edges = {}
            f = "{0:0"+str(bits)+"b}"
            for src in range(0, len(nodes)+len(new_nodes)):
                for dst in range(src, len(nodes)+len(new_nodes)):
                    if not (src,dst) in edges_full:
                        if count_ones(src^dst)==1:
                            if not src in nodes:
                                new_sub_edges[(src, dst)] = arches[bits](new_nodes[src], new_nodes[dst])
                            else:
                                assert(src in nodes)
                                new_cross_edges[(src, dst)] = arches[bits](nodes[src], new_nodes[dst])
            self.play(*[n.animate.shift(shifts[bits]) for n in nodes.values()],
                        *[e.animate.shift(shifts[bits]) for e in edges_full.values()])
            update_labels = [draw_node(i, bits, nodes[i].get_center()) for i in range(0, len(nodes))]

            self.play(*[Transform(nodes[i], update_labels[i]) for i in range(0, len(nodes))]) 

            for i in range(0, len(nodes)):
                new_nodes[i+len(nodes)].move_to(nodes[i])
                new_nodes[i+len(nodes)].z_index = 1
            [new_sub_edges[(i,j)].put_start_and_end_on(new_nodes[i].get_center(),new_nodes[j].get_center()) for (i,j) in new_sub_edges]

            self.play(*[Create(n) for n in new_nodes.values()], *[Create(e) for e in new_sub_edges.values()], run_time=0.01)
            self.play(*[n.animate.shift([-2*i for i in shifts[bits]]) for n in new_nodes.values()],
                         *[e.animate.shift([-2*i for i in shifts[bits]]) for e in new_sub_edges.values()])
            if new_cross_edges:
                [new_cross_edges[(i,j)].put_start_and_end_on(nodes[i].get_center(),new_nodes[j].get_center()) for (i,j) in new_cross_edges]

                self.play(self.camera.frame.animate.set(width=widths[bits]),*[Create(e) for e in new_cross_edges.values()])
            self.wait(1)
            nodes = nodes | new_nodes
            edges_full = edges_full | new_cross_edges | new_sub_edges
            print(f'nodes: {nodes}')
            print(f'edges: {edges_full.keys()}')
        self.wait(2)


                          ITERATION 1 (width=14)
new_nodes = {1: VGroup(Dot, Text('1'))}


                                                                                                      

nodes: {0: VGroup(Dot, Text('0')), 1: VGroup(Dot, Text('1'))}
edges: dict_keys([(0, 1)])
                          ITERATION 2 (width=18)
new_nodes = {2: VGroup(Dot, Text('10')), 3: VGroup(Dot, Text('11'))}


                                                                                                              

nodes: {0: VGroup(Dot, Text('0')), 1: VGroup(Dot, Text('1')), 2: VGroup(Dot, Text('10')), 3: VGroup(Dot, Text('11'))}
edges: dict_keys([(0, 1), (0, 2), (1, 3), (2, 3)])
                          ITERATION 3 (width=30)
new_nodes = {4: VGroup(Dot, Text('100')), 5: VGroup(Dot, Text('101')), 6: VGroup(Dot, Text('110')), 7: VGroup(Dot, Text('111'))}


                                                                                                    

nodes: {0: VGroup(Dot, Text('0')), 1: VGroup(Dot, Text('1')), 2: VGroup(Dot, Text('10')), 3: VGroup(Dot, Text('11')), 4: VGroup(Dot, Text('100')), 5: VGroup(Dot, Text('101')), 6: VGroup(Dot, Text('110')), 7: VGroup(Dot, Text('111'))}
edges: dict_keys([(0, 1), (0, 2), (1, 3), (2, 3), (0, 4), (1, 5), (2, 6), (3, 7), (4, 5), (4, 6), (5, 7), (6, 7)])
                          ITERATION 4 (width=32)
new_nodes = {8: VGroup(Dot, Text('1000')), 9: VGroup(Dot, Text('1001')), 10: VGroup(Dot, Text('1010')), 11: VGroup(Dot, Text('1011')), 12: VGroup(Dot, Text('1100')), 13: VGroup(Dot, Text('1101')), 14: VGroup(Dot, Text('1110')), 15: VGroup(Dot, Text('1111'))}


                                                                                                    

nodes: {0: VGroup(Dot, Text('0')), 1: VGroup(Dot, Text('1')), 2: VGroup(Dot, Text('10')), 3: VGroup(Dot, Text('11')), 4: VGroup(Dot, Text('100')), 5: VGroup(Dot, Text('101')), 6: VGroup(Dot, Text('110')), 7: VGroup(Dot, Text('111')), 8: VGroup(Dot, Text('1000')), 9: VGroup(Dot, Text('1001')), 10: VGroup(Dot, Text('1010')), 11: VGroup(Dot, Text('1011')), 12: VGroup(Dot, Text('1100')), 13: VGroup(Dot, Text('1101')), 14: VGroup(Dot, Text('1110')), 15: VGroup(Dot, Text('1111'))}
edges: dict_keys([(0, 1), (0, 2), (1, 3), (2, 3), (0, 4), (1, 5), (2, 6), (3, 7), (4, 5), (4, 6), (5, 7), (6, 7), (0, 8), (1, 9), (2, 10), (3, 11), (4, 12), (5, 13), (6, 14), (7, 15), (8, 9), (8, 10), (8, 12), (9, 11), (9, 13), (10, 11), (10, 14), (11, 15), (12, 13), (12, 14), (13, 15), (14, 15)])
                          ITERATION 5 (width=34)
new_nodes = {16: VGroup(Dot, Text('10000')), 17: VGroup(Dot, Text('10001')), 18: VGroup(Dot, Text('10010')), 19: VGroup(Dot, Text('10011')), 20: VGroup(Dot, Text('10100')),

                                                                                                    

nodes: {0: VGroup(Dot, Text('0')), 1: VGroup(Dot, Text('1')), 2: VGroup(Dot, Text('10')), 3: VGroup(Dot, Text('11')), 4: VGroup(Dot, Text('100')), 5: VGroup(Dot, Text('101')), 6: VGroup(Dot, Text('110')), 7: VGroup(Dot, Text('111')), 8: VGroup(Dot, Text('1000')), 9: VGroup(Dot, Text('1001')), 10: VGroup(Dot, Text('1010')), 11: VGroup(Dot, Text('1011')), 12: VGroup(Dot, Text('1100')), 13: VGroup(Dot, Text('1101')), 14: VGroup(Dot, Text('1110')), 15: VGroup(Dot, Text('1111')), 16: VGroup(Dot, Text('10000')), 17: VGroup(Dot, Text('10001')), 18: VGroup(Dot, Text('10010')), 19: VGroup(Dot, Text('10011')), 20: VGroup(Dot, Text('10100')), 21: VGroup(Dot, Text('10101')), 22: VGroup(Dot, Text('10110')), 23: VGroup(Dot, Text('10111')), 24: VGroup(Dot, Text('11000')), 25: VGroup(Dot, Text('11001')), 26: VGroup(Dot, Text('11010')), 27: VGroup(Dot, Text('11011')), 28: VGroup(Dot, Text('11100')), 29: VGroup(Dot, Text('11101')), 30: VGroup(Dot, Text('11110')), 31: VGroup(Dot, Text('11111'))}
edges: d