In [6]:
import networkx as nx
from manim import *
import math
import vizutils as viz

config.media_width = "100%"
config.verbosity = "WARNING"

In [7]:
# Bitwise operations needed
def flip_bit(a, n):
	# a: number to modify
	# n: bit to flip (starting at 0)
	x = 1<<n
	return a^x

def count_ones(x):
	res = 0
	#f = "{0:0"+str(8)+"b}"

	while x!=0:
		#print(f.format(x))
		if (x%2 == 1):
			res = res+1
		x = x>>1
	#print(f'{f.format(x)} result = {res}')
	return res


In [8]:
# Test bitwise operations
assert(flip_bit(7, 0) == 6)
assert(flip_bit(7, 1) == 5)
assert(flip_bit(7, 2) == 3)
assert(flip_bit(7, 3) == 15)

In [9]:
%%manim -qm SimpleComparator

config.frame_width = 8

class SimpleComparatorImpl:
    def __init__(self, **kwargs):
        self.null_label = kwargs['null_label'] if 'null_label' in kwargs else '-'
        self.verbose = kwargs['verbose'] if 'verbose' in kwargs else False

        self.Comparator = nx.DiGraph()
        self.layout = {}

        self.Comparator.add_node(0, weight = {'value': self.null_label, 'color':WHITE})
        self.Comparator.add_node(1, weight = {'value': self.null_label, 'color':GREEN})
        self.Comparator.add_node(2, weight = {'value': self.null_label, 'color':WHITE})
        self.Comparator.add_node(3, weight = {'value': self.null_label, 'color':RED})

        self.Comparator.add_edge(0, 1, weight={"carry": self.null_label, "epoch": 0, 'color':GREEN})
        self.Comparator.add_edge(0, 3, weight={"carry": self.null_label, "epoch": 0, 'color':RED})
        self.Comparator.add_edge(2, 1, weight={"carry": self.null_label, "epoch": 0, 'color':GREEN})
        self.Comparator.add_edge(2, 3, weight={"carry": self.null_label, "epoch": 0, 'color':RED})

        self.layout[0] = [-1,  1, 0]
        self.layout[1] = [ 1,  1, 0]
        self.layout[2] = [-1, -1, 0]
        self.layout[3] = [ 1, -1, 0]

    def additional_items_to_draw(self, selfi):
        if (self.epoch == 2):
            selfi.add(Tex(f"min({self.a}, {self.b})", font_size = 24).next_to(selfi.G[1]))
            selfi.add(Tex(f"max({self.a}, {self.b})", font_size = 24).next_to(selfi.G[3]))
        if (self.epoch == 0):
            selfi.add(Tex(f"broadcast({self.a})", font_size = 24).next_to(selfi.G[0], LEFT))
            selfi.add(Tex(f"broadcast({self.b})", font_size = 24).next_to(selfi.G[2], LEFT))

        list = []
        return list


    def set_input(self, a, b):
        self.a = a
        self.b = b

    def get(self):
        return self.Comparator, self.layout

    def init_computation(self):
        self.Comparator.nodes[0]['weight']['value'] = str(self.a)
        self.Comparator.nodes[2]['weight']['value'] = str(self.b)
        self.epoch = 0

    def compute_iteration(self):
        some_updates = False

        if self.epoch==2:
            return False

        if self.epoch == 0:
            for u,v,d in self.Comparator.out_edges(0, data=True):
                d['weight']['carry'] = self.Comparator.nodes[u]['weight']['value']
            for u,v,d in self.Comparator.out_edges(2, data=True):
                d['weight']['carry'] = self.Comparator.nodes[u]['weight']['value']
            some_updates = True
            self.epoch = self.epoch + 1
        else:
            in1 = []
            in3 = []
            for u,v,d in self.Comparator.in_edges(1, data=True):
                in1.append(int(d['weight']['carry']))
                d['weight']['carry'] = self.null_label
            for u,v,d in self.Comparator.in_edges(3, data=True):
                in3.append(int(d['weight']['carry']))
                d['weight']['carry'] = self.null_label
            self.Comparator.nodes[1]['weight']['value'] = str(min(in1))
            self.Comparator.nodes[3]['weight']['value'] = str(max(in3))
            some_updates = True
            self.epoch = self.epoch + 1

        return some_updates

class SimpleComparator(Scene):
    
    def construct(self):
        nbits = 8
        network = SimpleComparatorImpl()
        network.set_input(8, 5)
        network.init_computation()
        viz.generic_run(self, network)
        #print(network.read_output())

Iteration 0


                                                                                                       

Iteration 1


                                                                                                                          

Iteration 2


                                                                                                                    

## Odd-Even Sort

In [10]:
class odd_even:
    def node_id(self, line, stage, comparator):
        assert(comparator>=0 and comparator < 4)
        assert(line < self.lines)
        assert(stage < self.lines)
        return (line<<int(math.log2(self.lines)*2+2)) + (stage<<int(math.log2(self.lines)+2)) + comparator
    
    def items_from_node_id(self, nodeid):
        return ((nodeid &  ((1<<int(math.log2(self.lines)))-1)<<int(math.log2(self.lines)*2+2)) >> int(math.log2(self.lines)*2+2),
                (nodeid &  ((1<<int(math.log2(self.lines)))-1)<<int(math.log2(self.lines)+2)) >> int(math.log2(self.lines)+2),
                nodeid&3)

    def input_node(self, line):
        return -(line+1)
    
    def inputs(self):
        return range(-1, -(self.lines+1), -1)
    
    def output_node(self, line):
        return -(line+1)-2*self.lines

    def outputs(self):
        return range(-1-2*self.lines, -(self.lines+1)-2*self.lines, -1)
    
    def __init__(self, nlines, **kwargs):
        #print(kwargs)
        self.null_label = kwargs['null_label'] if 'null_label' in kwargs else '-'
        self.verbose = kwargs['verbose'] if 'verbose' in kwargs else False

        # nbits is the number of bits of the operation
        # N is the number of nodes in the tree (2^n-1). Nodes in the tree will be from 0 to 2^n-2
        self.lines = nlines
        self.N = 2*nlines-1

        self.OddEven = nx.DiGraph()
        self.layout = {}
        for stage in range(0, self.lines):
            for line in range((stage%2), self.lines-(stage%2), 2):
                #print(f'{stage}, {line}: {self.node_id(line, stage, 0)} {self.node_id(line, stage, 1)} {self.node_id(line+1, stage, 0)} {self.node_id(line+1, stage, 1)}')
                if stage%2 == 0:
                    self.OddEven.add_node(self.node_id(line, stage, 0), weight = {'value': self.null_label, 'color':WHITE})
                    self.OddEven.add_node(self.node_id(line, stage, 1), weight = {'value': self.null_label, 'color':GREEN})
                    self.OddEven.add_node(self.node_id(line+1, stage, 0), weight = {'value': self.null_label, 'color':WHITE})
                    self.OddEven.add_node(self.node_id(line+1, stage, 1), weight = {'value': self.null_label, 'color':RED})
                    self.OddEven.add_edge(self.node_id(line, stage, 0), self.node_id(line, stage, 1), weight={"carry": self.null_label, "epoch": 0, 'color':GREEN})
                    self.OddEven.add_edge(self.node_id(line+1, stage, 0), self.node_id(line+1, stage, 1), weight={"carry": self.null_label, "epoch": 0, 'color':RED})
                    self.OddEven.add_edge(self.node_id(line, stage, 0), self.node_id(line+1, stage, 1), weight={"carry": self.null_label, "epoch": 0, 'color':RED})
                    self.OddEven.add_edge(self.node_id(line+1, stage, 0), self.node_id(line, stage, 1), weight={"carry": self.null_label, "epoch": 0, 'color':GREEN})
                    self.layout[self.node_id(line, stage, 0)] = [2*stage - self.lines+1, -(line - self.lines/2+1), 0]
                    self.layout[self.node_id(line, stage, 1)] = [2*stage+1 - self.lines+1, -(line - self.lines/2+1), 0]
                    self.layout[self.node_id(line+1, stage, 0)] = [2*stage - self.lines+1, -(line+1 - self.lines/2+1), 0]
                    self.layout[self.node_id(line+1, stage, 1)] = [2*stage+1 - self.lines+1, -(line+1 - self.lines/2+1), 0]
                else:
                    self.OddEven.add_node(self.node_id(line, stage, 0), weight = {'value': self.null_label, 'color':WHITE})
                    self.OddEven.add_node(self.node_id(line, stage, 1), weight = {'value': self.null_label, 'color':GREEN})
                    self.OddEven.add_node(self.node_id(line+1, stage, 0), weight = {'value': self.null_label, 'color':WHITE})
                    self.OddEven.add_node(self.node_id(line+1, stage, 1), weight = {'value': self.null_label, 'color':RED})
                    self.OddEven.add_edge(self.node_id(line, stage, 0), self.node_id(line, stage, 1), weight={"carry": self.null_label, "epoch": 0, 'color':GREEN})
                    self.OddEven.add_edge(self.node_id(line+1, stage, 0), self.node_id(line+1, stage, 1), weight={"carry": self.null_label, "epoch": 0, 'color':RED})
                    self.OddEven.add_edge(self.node_id(line, stage, 0), self.node_id(line+1, stage, 1), weight={"carry": self.null_label, "epoch": 0, 'color':RED})
                    self.OddEven.add_edge(self.node_id(line+1, stage, 0), self.node_id(line, stage, 1), weight={"carry": self.null_label, "epoch": 0, 'color':GREEN})
                    self.layout[self.node_id(line, stage, 0)] = [2*stage - self.lines+1, -(line - self.lines/2+1), 0]
                    self.layout[self.node_id(line, stage, 1)] = [2*stage+1 - self.lines+1, -(line - self.lines/2+1), 0]
                    self.layout[self.node_id(line+1, stage, 0)] = [2*stage - self.lines+1, -(line+1 - self.lines/2+1), 0]
                    self.layout[self.node_id(line+1, stage, 1)] = [2*stage+1 - self.lines+1, -(line+1 - self.lines/2+1), 0]

        # Input and output nodes
        for line in range(0, self.lines):
            self.OddEven.add_node(self.input_node(line), weight = {'value': self.null_label, 'color':WHITE})
            self.OddEven.add_edge(self.input_node(line), self.node_id(line, 0, 0), weight={"carry": self.null_label, "epoch": 0, 'color':WHITE})
            self.OddEven.add_node(self.output_node(line), weight = {'value': self.null_label, 'color':WHITE})
            if line == 0 or line == self.lines-1:
                self.OddEven.add_edge(self.node_id(line, self.lines-2, 1), self.output_node(line), weight={"carry": self.null_label, "epoch": 0, 'color':WHITE})
            else:
                self.OddEven.add_edge(self.node_id(line, self.lines-1, 1), self.output_node(line), weight={"carry": self.null_label, "epoch": 0, 'color':WHITE})
            self.layout[self.input_node(line)] = [0 - self.lines - 1+1, -(line - self.lines/2+1), 0]
            self.layout[self.output_node(line)] = [2*self.lines - self.lines - 2+3, -(line - self.lines/2+1), 0]

        # Inter-comparator connections
        for gap in range(0, self.lines-1):
            for line in range(1, self.lines-1):
                assert(self.node_id(line, gap, 1) in self.OddEven.nodes)
                assert(self.node_id(line, gap+1, 0) in self.OddEven.nodes)
                self.OddEven.add_edge(self.node_id(line, gap, 1), self.node_id(line, gap+1, 0), weight={"carry": self.null_label, "epoch": 0, 'color':WHITE})
        for gap in range(0, self.lines-2, 2):
            assert(self.node_id(0, gap, 1) in self.OddEven.nodes)
            assert(self.node_id(0, gap+2, 0) in self.OddEven.nodes)
            self.OddEven.add_edge(self.node_id(0, gap, 1), self.node_id(0, gap+2, 0), weight={"carry": self.null_label, "epoch": 0, 'color':WHITE})
            self.OddEven.add_edge(self.node_id(self.lines-1, gap, 1), self.node_id(self.lines-1, gap+2, 0), weight={"carry": self.null_label, "epoch": 0, 'color':WHITE})

    def additional_items_to_draw(self, selfi):
        if self.epoch == 0:
            group = Group(*[selfi.G[x] for x in self.inputs()])
            selfi.add(SurroundingRectangle(group))
            group = Group(*[selfi.G[x] for x in self.outputs()])
            selfi.add(SurroundingRectangle(group))

    def set_input(self, seq):
        self.seq = seq

    def get(self):
        return self.OddEven, self.layout

    def init_computation(self):
        for node, value in zip(self.inputs(), self.seq):
            self.OddEven.nodes[node]['weight']['value'] = str(value)
        self.epoch = 0

    def compute_iteration(self):
        some_updates = False
        for node in self.OddEven:
            if node in self.inputs():
                if self.epoch == 0:
                    for u, v, d in self.OddEven.out_edges(node, data=True):
                        assert(u==node)
                        if self.OddEven.nodes[node]['weight']['value'] != self.null_label:
                            some_updates = True
                            d['weight']['carry'] = self.OddEven.nodes[node]['weight']['value']
                            d['weight']['epoch'] = self.epoch+1
                else:
                    None
            else:
                if node in self.outputs():
                    in_a = None
                    for u, v, e in self.OddEven.in_edges(node, data=True):
                        in_a = e
                    assert(in_a != None)

                    if in_a['weight']['carry'] != self.null_label and in_a['weight']['epoch'] <= self.epoch:
                        some_updates = True
                        self.OddEven.nodes[node]['weight']['value'] = in_a['weight']['carry']
                        in_a['weight']['carry'] = self.null_label

                else:
                    line, stage, comp = self.items_from_node_id(node)
                    #print(f'node {node} ({line}, {stage}, {comp})')
                    if comp == 0: # First stage of the comparator
                        in_a = None
                        for u, v, e in self.OddEven.in_edges(node, data=True):
                            in_a = e

                        out = []
                        for u, v, e in self.OddEven.out_edges(node, data=True):
                            out.append(e)

                        assert(len(out) == 2)
                        assert(in_a != None)
                        if in_a['weight']['carry'] != self.null_label and in_a['weight']['epoch'] <= self.epoch:
                            some_updates = True
                            self.OddEven.nodes[node]['weight']['value'] = in_a['weight']['carry']
                            out[0]['weight']['carry'] = self.OddEven.nodes[node]['weight']['value']
                            out[0]['weight']['epoch'] = self.epoch+1
                            out[1]['weight']['carry'] = self.OddEven.nodes[node]['weight']['value']
                            out[1]['weight']['epoch'] = self.epoch+1
                            in_a['weight']['carry'] = self.null_label
                    else:
                        assert(comp==1)
                        out_a = None
                        for u, v, e in self.OddEven.out_edges(node, data=True):
                            out_a = e
                        ins = []
                        for u, v, e in self.OddEven.in_edges(node, data=True):
                            ins.append(e)
                        assert(len(ins) == 2)
                        assert(out_a != None)
                        if ins[0]['weight']['carry'] != self.null_label and ins[1]['weight']['carry'] != self.null_label and ins[0]['weight']['epoch'] <= self.epoch and ins[1]['weight']['epoch'] <= self.epoch:
                            #if stage%2==0:
                                if line%2==stage%2:
                                    # Up
                                    some_updates = True
                                    out_a['weight']['carry'] = str(min(int(ins[0]['weight']['carry']), int(ins[1]['weight']['carry'])))
                                    self.OddEven.nodes[node]['weight']['value'] = out_a['weight']['carry']
                                    out_a['weight']['epoch'] = self.epoch+1
                                    ins[0]['weight']['carry'] = self.null_label
                                    ins[1]['weight']['carry'] = self.null_label
                                else:
                                    # Down
                                    some_updates = True
                                    out_a['weight']['carry'] = str(max(int(ins[0]['weight']['carry']), int(ins[1]['weight']['carry'])))
                                    self.OddEven.nodes[node]['weight']['value'] = out_a['weight']['carry']
                                    out_a['weight']['epoch'] = self.epoch+1
                                    ins[0]['weight']['carry'] = self.null_label
                                    ins[1]['weight']['carry'] = self.null_label

        self.epoch = self.epoch + 1
        return some_updates

    def read_output(self):
        return [self.OddEven.nodes[x]['weight']['value'] for x in self.outputs()]

In [11]:
import random

for l in range(1, 6):
    lines = 1<<l
    net = odd_even(lines)
    for line in range(0, lines, 2):
        for stage in range(0, lines, 2):
            for comp in range(0, 4):
                id = net.node_id(line, stage, comp)
                (l, s, c) = net.items_from_node_id(id)
                assert(line == l)
                assert(stage == s)
                assert(comp == c)
                id = net.node_id(line+1, stage, comp)
                (l, s, c) = net.items_from_node_id(id)
                assert(line+1 == l)
                assert(stage == s)
                assert(comp == c)
                id = net.node_id(line, stage+1, comp)
                (l, s, c) = net.items_from_node_id(id)
                assert(line == l)
                assert(stage+1 == s)
                assert(comp == c)
                id = net.node_id(line+1, stage+1, comp)
                (l, s, c) = net.items_from_node_id(id)
                assert(line+1 == l)
                assert(stage+1 == s)
                assert(comp == c)

nleaves = 8
test_seq = [x for x in range(nleaves, 0, -1)]
for i in range(0, 1000):
    network = odd_even(nleaves, verbose = False)
    network.set_input(test_seq)
    network.init_computation()
    while network.compute_iteration():
        None

    res = network.read_output()
    #print(res)
    v = res[0]
    for x in range(1, nleaves):
        assert(int(v) <= int(res[x]))
        v = int(res[x])
    test_seq = [random.randint(-nleaves, nleaves*nleaves) for x in range(0, nleaves)]

print("All is OK!")


All is OK!


In [13]:
%%manim -qm ShowOddEven

config.frame_width = 20


class ShowOddEven(Scene):
    
    def construct(self):
        nbits = 8
        network = odd_even(nbits, delay=0.1)
        network.set_input([8, 7, 6, 5, 4, 3, 2 ,1])
        network.init_computation()
        viz.generic_run(self, network)
        #print(network.read_output())

Iteration 0


                                                                                                                    

Iteration 1


                                                                                                                             

Iteration 2


                                                                                                                             

Iteration 3


                                                                                                                             

Iteration 4


                                                                                                                              

Iteration 5


                                                                                                                              

Iteration 6


                                                                                                                              

Iteration 7


                                                                                                                              

Iteration 8


                                                                                                                              

Iteration 9


                                                                                                                              

Iteration 10


                                                                                                                              

Iteration 11


                                                                                                                              

Iteration 12


                                                                                                                              

Iteration 13


                                                                                                                              

Iteration 14


                                                                                                                              

Iteration 15


                                                                                                                              

Iteration 16


                                                                                                                              

Iteration 17


                                                                                                                              

Iteration 18


                                                                                                                        

# Butterly Network (for DFT)

In [14]:
class butterfly:
    def node_id(self, line, stage, comparator):
        assert(comparator>=0 and comparator < 4)
        assert(line < self.lines)
        assert(stage < math.log2(self.lines))
        return (line<<int(math.log2(self.lines)*2+2)) + (stage<<int(math.log2(self.lines)+2)) + comparator
    
    def items_from_node_id(self, nodeid):
        return ((nodeid &  ((1<<int(math.log2(self.lines)))-1)<<int(math.log2(self.lines)*2+2)) >> int(math.log2(self.lines)*2+2),
                (nodeid &  ((1<<int(math.log2(self.lines)))-1)<<int(math.log2(self.lines)+2)) >> int(math.log2(self.lines)+2),
                nodeid&3)

    def input_node(self, line):
        return -(line+1)
    
    def inputs(self):
        return range(-1, -(self.lines+1), -1)
    
    def output_node(self, line):
        return -(line+1)-2*self.lines

    def outputs(self):
        return range(-1-2*self.lines, -(self.lines+1)-2*self.lines, -1)

    def __init__(self, nlines, **kwargs):
        #print(kwargs)
        self.null_label = kwargs['null_label'] if 'null_label' in kwargs else '-'
        self.verbose = kwargs['verbose'] if 'verbose' in kwargs else False

        # nbits is the number of bits of the operation
        # N is the number of nodes in the tree (2^n-1). Nodes in the tree will be from 0 to 2^n-2
        self.lines = nlines
        self.N = 2*nlines-1

        self.Butterfly = nx.DiGraph()

        node = 0
        self.layout = {}
        total_width = 0
        for stage in range(0, int(math.log2(self.lines))):
            bit_to_flip = int(math.log2(self.lines)) - stage - 1
            stage_width = 1+bit_to_flip 
            total_width = total_width + stage_width
            
        stage_x_starts = -(total_width+2+2)/2
        for stage in range(0, int(math.log2(self.lines))):
            bit_to_flip = int(math.log2(self.lines)) - stage - 1
            stage_width = 1+bit_to_flip
            for line in range(0, self.lines):
                if line&(1<<bit_to_flip) == 0:
                    # starting line
                    node0 = self.node_id(line, stage, 0)
                    node1 = self.node_id(line, stage, 1)
                    node2 = self.node_id( flip_bit(line, bit_to_flip), stage, 0)
                    node3 = self.node_id( flip_bit(line, bit_to_flip), stage, 1)
                    self.Butterfly.add_node(node0, weight = {'value': self.null_label, 'color': WHITE})
                    self.Butterfly.add_node(node1, weight = {'value': self.null_label, 'color': GREEN})
                    self.Butterfly.add_node(node2, weight = {'value': self.null_label, 'color': WHITE})
                    self.Butterfly.add_node(node3, weight = {'value': self.null_label, 'color': RED})
                    self.Butterfly.add_edge(node0, node1, weight={"carry": self.null_label, "epoch": 0, 'color': GREEN})
                    self.Butterfly.add_edge(node0, node3, weight={"carry": self.null_label, "epoch": 0, 'color': RED})
                    self.Butterfly.add_edge(node2, node1, weight={"carry": self.null_label, "epoch": 0, 'color': GREEN})
                    self.Butterfly.add_edge(node2, node3, weight={"carry": self.null_label, "epoch": 0, 'color': RED})
                    self.layout[node0] = [ stage_x_starts,   - line + self.lines/2, 0 ]
                    self.layout[node1] = [ stage_x_starts+stage_width, - line + self.lines/2, 0 ]
                    self.layout[node2] = [ stage_x_starts,   - flip_bit(line, bit_to_flip) + self.lines/2, 0 ]
                    self.layout[node3] = [ stage_x_starts+stage_width, - flip_bit(line, bit_to_flip) + self.lines/2, 0 ]

            stage_x_starts = stage_x_starts + stage_width + 1

        minx = min([x[0] for x in self.layout.values()])
        maxx = max([x[0] for x in self.layout.values()])
        miny = min([x[1] for x in self.layout.values()])
        maxy = max([x[1] for x in self.layout.values()])

        # Input and output nodes
        for line in range(0, self.lines):
            self.Butterfly.add_node(self.input_node(line), weight = {'value': self.null_label, 'color': WHITE})
            self.Butterfly.add_edge(self.input_node(line), self.node_id(line, 0, 0), weight={"carry": self.null_label, "epoch": 0, 'color': YELLOW})
            self.Butterfly.add_node(self.output_node(line), weight = {'value': self.null_label, 'color': WHITE})
            self.Butterfly.add_edge(self.node_id(line, int(math.log2(self.lines))-1, 1), self.output_node(line), weight={"carry": self.null_label, "epoch": 0, 'color': YELLOW})
            self.layout[self.input_node(line)] = [minx-1, maxy - line, 0]
            self.layout[self.output_node(line)] = [maxx+1, maxy - line, 0]

        for stage in range(0, int(math.log2(self.lines))-1):
            for line in range(0, self.lines):
                source = self.node_id(line, stage, 1)
                dest   = self.node_id(line, stage+1, 0)
                self.Butterfly.add_edge(source, dest, weight={"carry": self.null_label, "epoch": 0, 'color': WHITE})

        self.epoch = 0
        
    def additional_items_to_draw(self, selfi):
        if self.epoch == 0:
            group = Group(*[selfi.G[x] for x in self.inputs()])
            selfi.add(SurroundingRectangle(group))
            group = Group(*[selfi.G[x] for x in self.outputs()])
            selfi.add(SurroundingRectangle(group))

    def get(self):
        return (self.Butterfly, self.layout)#, self.edges_labels_pos)

    def set_input(self, in_seq):
        self.in_seq = in_seq

    def init_computation(self):
        for n, v in zip(self.inputs(), self.in_seq):
            self.Butterfly.nodes[n]['weight']['value'] = str(v)

    def compute_iteration(self):
        None
        
    def read_output(self):
        res = []
        for n, v in zip(self.outputs(), self.in_seq):
            res.append(self.Butterfly.nodes[n]['weight']['value'])
        return res


In [15]:
%%manim -qm ShowButterfly

config.frame_width = 32
import random

class ShowButterfly(Scene):
    
    def construct(self):
        nleaves = 16
        network = butterfly(nleaves, delay=1)
        firsthalf = [random.randint(0, nleaves) for x in range(0, nleaves>>1)]
        firsthalf.sort()
        secondhalf = [random.randint(0, nleaves) for x in range(0, nleaves>>1)]
        secondhalf.sort(reverse=True)
        network.set_input(firsthalf + secondhalf)
        viz.generic_run(self, network)
        #print(network.read_output())

Iteration 0


                                                                                                                    

In [16]:
%%manim -qm BitonicMerge

config.frame_width = 32

class BitonicSequenceMerge(butterfly):
    def __init__(self, nlines, **kwargs):
        super(BitonicSequenceMerge, self).__init__(nlines, **kwargs)

    def compute_iteration(self):
        some_updates = False
        for node in self.Butterfly:
            if node in self.inputs():
                if self.epoch == 0:
                    for u, v, d in self.Butterfly.out_edges(node, data=True):
                        assert(u==node)
                        if self.Butterfly.nodes[node]['weight']['value'] != self.null_label:
                            some_updates = True
                            d['weight']['carry'] = self.Butterfly.nodes[node]['weight']['value']
                            d['weight']['epoch'] = self.epoch+1
                else:
                    None
            else:
                if node in self.outputs():
                    in_a = None
                    for u, v, e in self.Butterfly.in_edges(node, data=True):
                        in_a = e
                    assert(in_a != None)

                    if in_a['weight']['carry'] != self.null_label and in_a['weight']['epoch'] <= self.epoch:
                        some_updates = True
                        self.Butterfly.nodes[node]['weight']['value'] = in_a['weight']['carry']
                        in_a['weight']['carry'] = self.null_label

                else:
                    line, stage, comp = self.items_from_node_id(node)
                    #print(f'node {node} ({line}, {stage}, {comp})')
                    if comp == 0: # First stage of the comparator
                        in_a = None
                        for u, v, e in self.Butterfly.in_edges(node, data=True):
                            in_a = e

                        out = []
                        for u, v, e in self.Butterfly.out_edges(node, data=True):
                            out.append(e)

                        assert(len(out) == 2)
                        assert(in_a != None)
                        if in_a['weight']['carry'] != self.null_label and in_a['weight']['epoch'] <= self.epoch:
                            some_updates = True
                            self.Butterfly.nodes[node]['weight']['value'] = in_a['weight']['carry']
                            out[0]['weight']['carry'] = self.Butterfly.nodes[node]['weight']['value']
                            out[0]['weight']['epoch'] = self.epoch+1
                            out[1]['weight']['carry'] = self.Butterfly.nodes[node]['weight']['value']
                            out[1]['weight']['epoch'] = self.epoch+1
                            in_a['weight']['carry'] = self.null_label
                    else:
                        assert(comp==1)
                        out_a = None
                        for u, v, e in self.Butterfly.out_edges(node, data=True):
                            out_a = e
                        ins = []
                        for u, v, e in self.Butterfly.in_edges(node, data=True):
                            ins.append(e)
                        assert(len(ins) == 2)
                        assert(out_a != None)
                        if ins[0]['weight']['carry'] != self.null_label and ins[1]['weight']['carry'] != self.null_label and ins[0]['weight']['epoch'] <= self.epoch and ins[1]['weight']['epoch'] <= self.epoch:
                            #if stage%2==0:
                                bit_to_flip = int(math.log2(self.lines)) - stage - 1
                                if line&(1<<bit_to_flip) == 0:
                                    # Up
                                    some_updates = True
                                    out_a['weight']['carry'] = str(min(int(ins[0]['weight']['carry']), int(ins[1]['weight']['carry'])))
                                    self.Butterfly.nodes[node]['weight']['value'] = out_a['weight']['carry']
                                    out_a['weight']['epoch'] = self.epoch+1
                                    ins[0]['weight']['carry'] = self.null_label
                                    ins[1]['weight']['carry'] = self.null_label
                                else:
                                    # Down
                                    some_updates = True
                                    out_a['weight']['carry'] = str(max(int(ins[0]['weight']['carry']), int(ins[1]['weight']['carry'])))
                                    self.Butterfly.nodes[node]['weight']['value'] = out_a['weight']['carry']
                                    out_a['weight']['epoch'] = self.epoch+1
                                    ins[0]['weight']['carry'] = self.null_label
                                    ins[1]['weight']['carry'] = self.null_label

        self.epoch = self.epoch + 1
        return some_updates

def test_bitonic_merge():
    nleaves = 16
    firsthalf = [x for x in range(0, nleaves>>1)]
    firsthalf.sort()
    secondhalf = [x for x in range(0, nleaves>>1)]
    secondhalf.sort(reverse=True)
    test_seq = firsthalf + secondhalf
    for i in range(0, 100):
        network = BitonicSequenceMerge(nleaves, verbose = False)
        network.set_input(test_seq)
        network.init_computation()
        while network.compute_iteration():
            None

        res = network.read_output()
        #print(res)
        v = res[0]
        for x in range(1, nleaves):
            assert(int(v) <= int(res[x]))
            v = int(res[x])
        firsthalf = [random.randint(0, nleaves) for x in range(0, nleaves>>1)]
        firsthalf.sort()
        secondhalf = [random.randint(0, nleaves) for x in range(0, nleaves>>1)]
        secondhalf.sort(reverse=True)
        test_seq = firsthalf + secondhalf

    print("All is OK!")


class BitonicMerge(Scene):
    
    def construct(self):
        test_bitonic_merge()
        nleaves = 16
        network = BitonicSequenceMerge(nleaves, delay=1)
        firsthalf = [random.randint(0, 4*nleaves) for x in range(0, nleaves>>1)]
        #firsthalf = [x for x in range(0, nleaves>>1)]
        firsthalf.sort()
        secondhalf = [random.randint(0, 4*nleaves) for x in range(0, nleaves>>1)]
        #secondhalf = [x for x in range(0, nleaves>>1)]
        secondhalf.sort(reverse=True)
        inseq = firsthalf + secondhalf
        for i in range(0, 5):
            inseq.append(inseq.pop(0))
        network.set_input(inseq)
        viz.generic_run(self, network)
        #print(network.read_output())

All is OK!
Iteration 0


                                                                                                                    

Iteration 1


                                                                                                                             

Iteration 2


                                                                                                                             

Iteration 3


                                                                                                                             

Iteration 4


                                                                                                                              

Iteration 5


                                                                                                                              

Iteration 6


                                                                                                                              

Iteration 7


                                                                                                                              

Iteration 8


                                                                                                                              

Iteration 9


                                                                                                                              

Iteration 10


                                                                                                                        

In [17]:
%%manim -qm Hypercube

config.frame_width = 10
import random

def draw_node(id, bits, pos):
    result = VGroup()
    f = "{0:0"+str(bits)+"b}"
    text = Text(f.format(id), color=BLACK, font='monospace', font_size=16).move_to(pos)
    #print(f'radius=max(text.width, text.height) = radius=max({text.width}, {text.height}) = {max(text.width, text.height)}')
    circle = Dot(point=pos, radius=max(text.width, text.height), color=WHITE)
    return result.add(circle, text)

class myarch(ArcBetweenPoints):
    def __init__(self, *args, **kwargs):
        super(myarch, self).__init__(angle=PI/3, *args, **kwargs)

class Hypercube(MovingCameraScene):

    def construct(self):
        edges_full = {}
        new_nodes = []

        nodes = {0: draw_node(0, 0, [0,0,0])}
        nodes[0].z_index = 1
        self.add(nodes[0])
        Title = Text("Hyercube").move_to(UP)

        shifts = [[0,0,0], [-2, 0, 0], [0,-2,0], [-1,-1,0], [-2, -4, 0], [-4, 1.2, 0], [-9,-2,0]]
        arches = [Line, Line, Line, Line, Line, Line, Line, Line]
        colors = [WHITE, WHITE, RED, YELLOW, GREEN, BLUE, WHITE]
        widths = [0, 14, 18, 30, 32, 34, 34, 64]
        for bits in range(1, 6):
            print(f'                          ITERATION {bits} (width={widths[bits]})')
            new_nodes = {i+len(nodes) : draw_node(i+len(nodes), bits, nodes[i].get_center()) for i in range(0, len(nodes))}
            print(f'new_nodes = {new_nodes}')
            new_sub_edges = {}
            new_cross_edges = {}
            f = "{0:0"+str(bits)+"b}"
            for src in range(0, len(nodes)+len(new_nodes)):
                for dst in range(src, len(nodes)+len(new_nodes)):
                    if not (src,dst) in edges_full:
                        if count_ones(src^dst)==1:
                            if not src in nodes:
                                new_sub_edges[(src, dst)] = arches[bits](new_nodes[src], new_nodes[dst], color=colors[int(math.log2(src^dst)+1)], path_arc=PI/6)
                            else:
                                assert(src in nodes)
                                new_cross_edges[(src, dst)] = arches[bits](nodes[src], new_nodes[dst], color=colors[bits], path_arc=PI/6)
            self.play(*[n.animate.shift(shifts[bits]) for n in nodes.values()],
                        *[e.animate.shift(shifts[bits]) for e in edges_full.values()])
            update_labels = [draw_node(i, bits, nodes[i].get_center()) for i in range(0, len(nodes))]

            self.play(*[Transform(nodes[i], update_labels[i]) for i in range(0, len(nodes))]) 

            for i in range(0, len(nodes)):
                new_nodes[i+len(nodes)].move_to(nodes[i])
                new_nodes[i+len(nodes)].z_index = 1
            [new_sub_edges[(i,j)].put_start_and_end_on(new_nodes[i].get_center(),new_nodes[j].get_center()) for (i,j) in new_sub_edges]

            self.play(*[Create(n) for n in new_nodes.values()], *[Create(e) for e in new_sub_edges.values()], run_time=0.05)
            self.play(*[n.animate.shift([-2*i for i in shifts[bits]]) for n in new_nodes.values()],
                         *[e.animate.shift([-2*i for i in shifts[bits]]) for e in new_sub_edges.values()])
            if new_cross_edges:
                [new_cross_edges[(i,j)].put_start_and_end_on(nodes[i].get_center(),new_nodes[j].get_center()) for (i,j) in new_cross_edges]

                self.play(self.camera.frame.animate.set(width=widths[bits]),*[Create(e) for e in new_cross_edges.values()])
            self.wait(1)
            nodes = nodes | new_nodes
            edges_full = edges_full | new_cross_edges | new_sub_edges
            print(f'nodes: {nodes}')
            print(f'edges: {edges_full.keys()}')
        self.wait(2)


                          ITERATION 1 (width=14)
new_nodes = {1: VGroup(Dot, Text('1'))}


                                                                                              

nodes: {0: VGroup(Dot, Text('0')), 1: VGroup(Dot, Text('1'))}
edges: dict_keys([(0, 1)])
                          ITERATION 2 (width=18)
new_nodes = {2: VGroup(Dot, Text('10')), 3: VGroup(Dot, Text('11'))}


                                                                                                    

nodes: {0: VGroup(Dot, Text('0')), 1: VGroup(Dot, Text('1')), 2: VGroup(Dot, Text('10')), 3: VGroup(Dot, Text('11'))}
edges: dict_keys([(0, 1), (0, 2), (1, 3), (2, 3)])
                          ITERATION 3 (width=30)
new_nodes = {4: VGroup(Dot, Text('100')), 5: VGroup(Dot, Text('101')), 6: VGroup(Dot, Text('110')), 7: VGroup(Dot, Text('111'))}


                                                                                                               

nodes: {0: VGroup(Dot, Text('0')), 1: VGroup(Dot, Text('1')), 2: VGroup(Dot, Text('10')), 3: VGroup(Dot, Text('11')), 4: VGroup(Dot, Text('100')), 5: VGroup(Dot, Text('101')), 6: VGroup(Dot, Text('110')), 7: VGroup(Dot, Text('111'))}
edges: dict_keys([(0, 1), (0, 2), (1, 3), (2, 3), (0, 4), (1, 5), (2, 6), (3, 7), (4, 5), (4, 6), (5, 7), (6, 7)])
                          ITERATION 4 (width=32)
new_nodes = {8: VGroup(Dot, Text('1000')), 9: VGroup(Dot, Text('1001')), 10: VGroup(Dot, Text('1010')), 11: VGroup(Dot, Text('1011')), 12: VGroup(Dot, Text('1100')), 13: VGroup(Dot, Text('1101')), 14: VGroup(Dot, Text('1110')), 15: VGroup(Dot, Text('1111'))}


                                                                                                               

nodes: {0: VGroup(Dot, Text('0')), 1: VGroup(Dot, Text('1')), 2: VGroup(Dot, Text('10')), 3: VGroup(Dot, Text('11')), 4: VGroup(Dot, Text('100')), 5: VGroup(Dot, Text('101')), 6: VGroup(Dot, Text('110')), 7: VGroup(Dot, Text('111')), 8: VGroup(Dot, Text('1000')), 9: VGroup(Dot, Text('1001')), 10: VGroup(Dot, Text('1010')), 11: VGroup(Dot, Text('1011')), 12: VGroup(Dot, Text('1100')), 13: VGroup(Dot, Text('1101')), 14: VGroup(Dot, Text('1110')), 15: VGroup(Dot, Text('1111'))}
edges: dict_keys([(0, 1), (0, 2), (1, 3), (2, 3), (0, 4), (1, 5), (2, 6), (3, 7), (4, 5), (4, 6), (5, 7), (6, 7), (0, 8), (1, 9), (2, 10), (3, 11), (4, 12), (5, 13), (6, 14), (7, 15), (8, 9), (8, 10), (8, 12), (9, 11), (9, 13), (10, 11), (10, 14), (11, 15), (12, 13), (12, 14), (13, 15), (14, 15)])
                          ITERATION 5 (width=34)
new_nodes = {16: VGroup(Dot, Text('10000')), 17: VGroup(Dot, Text('10001')), 18: VGroup(Dot, Text('10010')), 19: VGroup(Dot, Text('10011')), 20: VGroup(Dot, Text('10100')),

                                                                                                              

nodes: {0: VGroup(Dot, Text('0')), 1: VGroup(Dot, Text('1')), 2: VGroup(Dot, Text('10')), 3: VGroup(Dot, Text('11')), 4: VGroup(Dot, Text('100')), 5: VGroup(Dot, Text('101')), 6: VGroup(Dot, Text('110')), 7: VGroup(Dot, Text('111')), 8: VGroup(Dot, Text('1000')), 9: VGroup(Dot, Text('1001')), 10: VGroup(Dot, Text('1010')), 11: VGroup(Dot, Text('1011')), 12: VGroup(Dot, Text('1100')), 13: VGroup(Dot, Text('1101')), 14: VGroup(Dot, Text('1110')), 15: VGroup(Dot, Text('1111')), 16: VGroup(Dot, Text('10000')), 17: VGroup(Dot, Text('10001')), 18: VGroup(Dot, Text('10010')), 19: VGroup(Dot, Text('10011')), 20: VGroup(Dot, Text('10100')), 21: VGroup(Dot, Text('10101')), 22: VGroup(Dot, Text('10110')), 23: VGroup(Dot, Text('10111')), 24: VGroup(Dot, Text('11000')), 25: VGroup(Dot, Text('11001')), 26: VGroup(Dot, Text('11010')), 27: VGroup(Dot, Text('11011')), 28: VGroup(Dot, Text('11100')), 29: VGroup(Dot, Text('11101')), 30: VGroup(Dot, Text('11110')), 31: VGroup(Dot, Text('11111'))}
edges: d

In [18]:
class Batcher:
    def node_id(self, line, sub_stage, stage, comparator):
        assert(comparator>=0 and comparator < 4)
        if line >= self.lines:
            print(f'{line} < {self.lines}')
        assert(line < self.lines)
        assert(sub_stage < math.log2(self.lines))
        assert(stage < math.log2(self.lines))
        return (line<<int(math.log2(self.lines)*3+2)) + (sub_stage<<int(math.log2(self.lines)*2+2)) + (stage<<int(math.log2(self.lines)+2)) + comparator
    
    def items_from_node_id(self, nodeid):
        return ((nodeid &  ((1<<int(math.log2(self.lines)))-1)<<int(math.log2(self.lines)*3+2)) >> int(math.log2(self.lines)*3+2),
                (nodeid &  ((1<<int(math.log2(self.lines)))-1)<<int(math.log2(self.lines)*2+2)) >> int(math.log2(self.lines)*2+2),
                (nodeid &  ((1<<int(math.log2(self.lines)))-1)<<int(math.log2(self.lines)+2)) >> int(math.log2(self.lines)+2),
                nodeid&3)

    def input_node(self, line):
        return -(line+1)
    
    def inputs(self):
        return range(-1, -(self.lines+1), -1)
    
    def output_node(self, line):
        return -(line+1)-2*self.lines

    def outputs(self):
        return range(-1-2*self.lines, -(self.lines+1)-2*self.lines, -1)

    def __init__(self, nlines, **kwargs):
        #print(kwargs)
        self.null_label = kwargs['null_label'] if 'null_label' in kwargs else '-'
        self.verbose = kwargs['verbose'] if 'verbose' in kwargs else False

        # nbits is the number of bits of the operation
        # N is the number of nodes in the tree (2^n-1). Nodes in the tree will be from 0 to 2^n-2
        self.lines = nlines
        self.N = 2*nlines-1

        self.Butterfly = nx.DiGraph()

        node = 0
        self.layout = {}
        total_width = 0
        for stage in range(0, int(math.log2(self.lines))):
            sub_net_size = 1<<(stage+1)
            for sub_stage in range(0, int(math.log2(sub_net_size))):
                bit_to_flip = int(math.log2(sub_net_size)) - sub_stage - 1
                stage_width = 1+bit_to_flip
                total_width = total_width + stage_width + 1
        stage_x_starts = -((total_width)/2)+0.5

        for stage in range(0, int(math.log2(self.lines))):
            line_mask = (1<<stage+1)-1
            sub_net_size = 1<<(stage+1)
            number_of_networks = int(self.lines / sub_net_size)
            global_start = stage_x_starts
            for starting_line in range(0, self.lines, 1<<(stage+1)):
                stage_x_starts = global_start
                for sub_stage in range(0, int(math.log2(sub_net_size))):
                    bit_to_flip = int(math.log2(sub_net_size)) - sub_stage - 1
                    stage_width = 1+bit_to_flip
                    #print(f'line_mask={line_mask}, sub_stage={sub_stage}, starting_line={starting_line}, line_mask={line_mask}, number_of_networks={number_of_networks}, sub_net_size={sub_net_size}', end=' ')
                    for line in range(starting_line, starting_line+sub_net_size):
                        if (line&line_mask)&(1<<bit_to_flip) == 0:
                            # starting line
                            #print(f'line={line}, flip_bit(line, bit_to_flip)={flip_bit(line, bit_to_flip)}')
                            node0 = self.node_id(line, sub_stage, stage, 0)
                            node1 = self.node_id(line, sub_stage, stage, 1)
                            node2 = self.node_id( flip_bit(line, bit_to_flip), sub_stage, stage, 0)
                            node3 = self.node_id( flip_bit(line, bit_to_flip), sub_stage, stage, 1)
                            #print(f'nodes: {node0}, {node1}, {node2}, {node3}')
                            sub_net_index = math.floor(line / sub_net_size)
                            #print(f'line={line}, sub_net_index={sub_net_index}, line&line_mask={line&line_mask}')
                            assert(line < flip_bit(line, bit_to_flip))
                            node1_color = GREEN if sub_net_index%2==0 else RED
                            node3_color = GREEN if sub_net_index%2==1 else RED
                            self.Butterfly.add_node(node0, weight = {'value': self.null_label, 'color': WHITE})
                            self.Butterfly.add_node(node1, weight = {'value': self.null_label, 'color': node1_color})
                            self.Butterfly.add_node(node2, weight = {'value': self.null_label, 'color': WHITE})
                            self.Butterfly.add_node(node3, weight = {'value': self.null_label, 'color': node3_color})
                            self.Butterfly.add_edge(node0, node1, weight={"carry": self.null_label, "epoch": 0, 'color': node1_color})
                            self.Butterfly.add_edge(node0, node3, weight={"carry": self.null_label, "epoch": 0, 'color': node3_color})
                            self.Butterfly.add_edge(node2, node1, weight={"carry": self.null_label, "epoch": 0, 'color': node1_color})
                            self.Butterfly.add_edge(node2, node3, weight={"carry": self.null_label, "epoch": 0, 'color': node3_color})
                            self.layout[node0] = [ stage_x_starts,   - line + self.lines/2, 0 ]
                            self.layout[node1] = [ stage_x_starts+stage_width, - line + self.lines/2, 0 ]
                            self.layout[node2] = [ stage_x_starts,   - flip_bit(line, bit_to_flip) + self.lines/2, 0 ]
                            self.layout[node3] = [ stage_x_starts+stage_width, - flip_bit(line, bit_to_flip) + self.lines/2, 0 ]

                    stage_x_starts = stage_x_starts + stage_width + 1

                for sub_stage in range(0, int(math.log2(sub_net_size))-1):
                    for line in range(0, self.lines):
                        source = self.node_id(line, sub_stage, stage, 1)
                        dest   = self.node_id(line, sub_stage+1, stage, 0)
                        self.Butterfly.add_edge(source, dest, weight={"carry": self.null_label, "epoch": 0, 'color': GREEN})

        # Links between sub-networks
        for stage in range(0, int(math.log2(self.lines))-1):
            sub_net_size = 1<<(stage+1)
            for line in range(0, self.lines):
                #print(f'self.node_id({line}, {int(math.log2(sub_net_size))-1}, {stage}, 1) - self.node_id({line}, 0, {stage+1}, 0)', end=' ')
                source = self.node_id(line, int(math.log2(sub_net_size))-1, stage, 1)
                dest   = self.node_id(line, 0, stage+1, 0)
                assert(source in self.Butterfly.nodes())
                assert(dest in self.Butterfly.nodes())
                #print(f'{source} -> {dest}')
                self.Butterfly.add_edge(source, dest, weight={"carry": self.null_label, "epoch": 0, 'color': YELLOW})

        minx = min([x[0] for x in self.layout.values()])
        maxx = max([x[0] for x in self.layout.values()])
        miny = min([x[1] for x in self.layout.values()])
        maxy = max([x[1] for x in self.layout.values()])

        # Input and output nodes
        for line in range(0, self.lines):
            self.Butterfly.add_node(self.input_node(line), weight = {'value': self.null_label, 'color': WHITE})
            self.Butterfly.add_edge(self.input_node(line), self.node_id(line, 0, 0, 0), weight={"carry": self.null_label, "epoch": 0, 'color': YELLOW})
            self.Butterfly.add_node(self.output_node(line), weight = {'value': self.null_label, 'color': WHITE})
            self.Butterfly.add_edge(self.node_id(line, int(math.log2(self.lines))-1, int(math.log2(self.lines))-1, 1), self.output_node(line), weight={"carry": self.null_label, "epoch": 0, 'color': YELLOW})
            self.layout[self.input_node(line)] = [minx-1, maxy - line, 0]
            self.layout[self.output_node(line)] = [maxx+1, maxy - line, 0]


        self.epoch = 0
        
    def additional_items_to_draw(self, selfi):
        if self.epoch == 0:
            group = Group(*[selfi.G[x] for x in self.inputs()])
            selfi.add(SurroundingRectangle(group))
            group = Group(*[selfi.G[x] for x in self.outputs()])
            selfi.add(SurroundingRectangle(group))
            for stage in range(0, int(math.log2(self.lines))):
                line_mask = (1<<stage+1)-1
                sub_net_size = 1<<(stage+1)
                for starting_line in range(0, self.lines, 1<<(stage+1)):
                    group_list = []
                    for sub_stage in range(0, int(math.log2(sub_net_size))):
                        bit_to_flip = int(math.log2(sub_net_size)) - sub_stage - 1
                        stage_width = 1+bit_to_flip
                        #print(f'line_mask={line_mask}, sub_stage={sub_stage}, starting_line={starting_line}, line_mask={line_mask}, number_of_networks={number_of_networks}, sub_net_size={sub_net_size}', end=' ')
                        for line in range(starting_line, starting_line+sub_net_size):
                            if (line&line_mask)&(1<<bit_to_flip) == 0:
                                # starting line
                                #print(f'line={line}, flip_bit(line, bit_to_flip)={flip_bit(line, bit_to_flip)}')
                                node0 = self.node_id(line, sub_stage, stage, 0)
                                node1 = self.node_id(line, sub_stage, stage, 1)
                                node2 = self.node_id( flip_bit(line, bit_to_flip), sub_stage, stage, 0)
                                node3 = self.node_id( flip_bit(line, bit_to_flip), sub_stage, stage, 1)
                                group_list = group_list + [node0, node1, node2, node3]
                    print(f'Adding rectangle around {len(group_list)} nodes')
                    group = Group(*[selfi.G[x] for x in group_list])
                    selfi.add(SurroundingRectangle(group, color=DARK_BLUE, corner_radius=0.2))



    def get(self):
        return (self.Butterfly, self.layout)#, self.edges_labels_pos)

    def set_input(self, in_seq):
        self.in_seq = in_seq

    def init_computation(self):
        for n, v in zip(self.inputs(), self.in_seq):
            self.Butterfly.nodes[n]['weight']['value'] = str(v)

    def no_compute_iteration(self):
        None

    def compute_iteration(self):
        some_updates = False
        for node in self.Butterfly:
            if node in self.inputs():
                if self.epoch == 0:
                    for u, v, d in self.Butterfly.out_edges(node, data=True):
                        assert(u==node)
                        if self.Butterfly.nodes[node]['weight']['value'] != self.null_label:
                            some_updates = True
                            d['weight']['carry'] = self.Butterfly.nodes[node]['weight']['value']
                            d['weight']['epoch'] = self.epoch+1
                else:
                    None
            else:
                if node in self.outputs():
                    in_a = None
                    for u, v, e in self.Butterfly.in_edges(node, data=True):
                        in_a = e
                    assert(in_a != None)

                    if in_a['weight']['carry'] != self.null_label and in_a['weight']['epoch'] <= self.epoch:
                        some_updates = True
                        self.Butterfly.nodes[node]['weight']['value'] = in_a['weight']['carry']
                        in_a['weight']['carry'] = self.null_label

                else:
                    line, sub_stage, stage, comp = self.items_from_node_id(node)
                    #print(f'node {node} ({line}, {stage}, {comp})')
                    if comp == 0: # First stage of the comparator
                        in_a = None
                        for u, v, e in self.Butterfly.in_edges(node, data=True):
                            in_a = e

                        out = []
                        for u, v, e in self.Butterfly.out_edges(node, data=True):
                            out.append(e)

                        assert(len(out) == 2)
                        assert(in_a != None)
                        if in_a['weight']['carry'] != self.null_label and in_a['weight']['epoch'] <= self.epoch:
                            some_updates = True
                            self.Butterfly.nodes[node]['weight']['value'] = in_a['weight']['carry']
                            out[0]['weight']['carry'] = self.Butterfly.nodes[node]['weight']['value']
                            out[0]['weight']['epoch'] = self.epoch+1
                            out[1]['weight']['carry'] = self.Butterfly.nodes[node]['weight']['value']
                            out[1]['weight']['epoch'] = self.epoch+1
                            in_a['weight']['carry'] = self.null_label
                    else:
                        assert(comp==1)
                        out_a = None
                        for u, v, e in self.Butterfly.out_edges(node, data=True):
                            out_a = e
                        ins = []
                        for u, v, e in self.Butterfly.in_edges(node, data=True):
                            ins.append(e)
                        assert(len(ins) == 2)
                        assert(out_a != None)
                        if ins[0]['weight']['carry'] != self.null_label and ins[1]['weight']['carry'] != self.null_label and ins[0]['weight']['epoch'] <= self.epoch and ins[1]['weight']['epoch'] <= self.epoch:
                            #if stage%2==0:
                                line_mask = (1<<stage+1)-1
                                sub_net_size = 1<<(stage+1)
                                sub_net_index = math.floor(line/sub_net_size)
                                bit_to_flip = int(math.log2(sub_net_size)) - sub_stage - 1
                                #comp_direction_up = True # up or down
                                if (line_mask&line)&(1<<bit_to_flip) == 0:
                                    # Up
                                    some_updates = True
                                    if  sub_net_index%2 == 0:
                                        out_a['weight']['carry'] = str(min(int(ins[0]['weight']['carry']), int(ins[1]['weight']['carry'])))
                                    else:
                                        out_a['weight']['carry'] = str(max(int(ins[0]['weight']['carry']), int(ins[1]['weight']['carry'])))
                                    self.Butterfly.nodes[node]['weight']['value'] = out_a['weight']['carry']
                                    out_a['weight']['epoch'] = self.epoch+1
                                    ins[0]['weight']['carry'] = self.null_label
                                    ins[1]['weight']['carry'] = self.null_label
                                else:
                                    # Down
                                    some_updates = True
                                    if  sub_net_index%2 == 0:
                                        out_a['weight']['carry'] = str(max(int(ins[0]['weight']['carry']), int(ins[1]['weight']['carry'])))
                                    else:
                                        out_a['weight']['carry'] = str(min(int(ins[0]['weight']['carry']), int(ins[1]['weight']['carry'])))
                                    self.Butterfly.nodes[node]['weight']['value'] = out_a['weight']['carry']
                                    out_a['weight']['epoch'] = self.epoch+1
                                    ins[0]['weight']['carry'] = self.null_label
                                    ins[1]['weight']['carry'] = self.null_label

        self.epoch = self.epoch + 1
        return some_updates
        
    def read_output(self):
        None
        res = []
        for n, v in zip(self.outputs(), self.in_seq):
            res.append(self.Butterfly.nodes[n]['weight']['value'])
        return res


In [20]:
%%manim -qm ShowBatcher8

config.frame_width = 32
import random

class ShowBatcher8(MovingCameraScene):
    def test():
        for l in range(1, 6):
            lines = 1<<l
            net = Batcher(lines)
            for line in range(0, lines, 2):
                for sub_stage in range(0, int(math.log2(lines)), 1):
                    for stage in range(0, int(math.log2(lines)), 1):
                        for comp in range(0, 4):
                            id = net.node_id(line, sub_stage,stage, comp)
                            (l, ss, s, c) = net.items_from_node_id(id)
                            assert(line == l)
                            assert(stage == s)
                            assert(sub_stage == ss)
                            assert(comp == c)
                            id = net.node_id(line+1, sub_stage+1, stage, comp)
                            (l, ss, s, c) = net.items_from_node_id(id)
                            assert(line+1 == l)
                            assert(stage == s)
                            assert(sub_stage+1 == ss)
                            assert(comp == c)
                            id = net.node_id(line, sub_stage, stage+1, comp)
                            (l, ss, s, c) = net.items_from_node_id(id)
                            assert(line == l)
                            assert(stage+1 == s)
                            assert(sub_stage == ss)
                            assert(comp == c)
                            id = net.node_id(line+1, sub_stage, stage+1, comp)
                            (l, ss, s, c) = net.items_from_node_id(id)
                            assert(line+1 == l)
                            assert(stage+1 == s)
                            assert(sub_stage == ss)
                            assert(comp == c)
    print("All is OK!")

    def construct(self):
        nleaves = 8
        network = Batcher(nleaves, delay=1)

        Network, layout = network.get()
        minx = min([x[0] for x in layout.values()])
        maxx = max([x[0] for x in layout.values()])
        self.play(self.camera.frame.animate.set(width=maxx-minx+4))

        in_seq = [random.randint(0, nleaves) for x in range(0, nleaves)]
        network.set_input(in_seq)
        viz.generic_run(self, network)
        #print(network.read_output())

All is OK!


                                                                                     

Iteration 0


                                                                                                                    

Adding rectangle around 4 nodes
Adding rectangle around 4 nodes
Adding rectangle around 4 nodes
Adding rectangle around 4 nodes
Adding rectangle around 16 nodes
Adding rectangle around 16 nodes
Adding rectangle around 48 nodes
Iteration 1


                                                                                                                             

Iteration 2


                                                                                                                             

Iteration 3


                                                                                                                             

Iteration 4


                                                                                                                              

Iteration 5


                                                                                                                              

Iteration 6


                                                                                                                              

Iteration 7


                                                                                                                              

Iteration 8


                                                                                                                              

Iteration 9


                                                                                                                              

Iteration 10


                                                                                                                              

Iteration 11


                                                                                                                              

Iteration 12


                                                                                                                              

Iteration 13


                                                                                                                              

Iteration 14


                                                                                                                        

In [21]:
for lnleaves in range(2, 5):
    nleaves = 1<<lnleaves
    test_seq = [x for x in range(nleaves, -1, -1)]
    for i in range(0, 1000):
        print(f'{nleaves} {i}', end='\r')
        network = Batcher(nleaves, verbose = False)
        network.set_input(test_seq)
        network.init_computation()
        while network.compute_iteration():
            None

        res = network.read_output()
        #print(res)
        v = res[0]
        for x in range(1, nleaves):
            assert(int(v) <= int(res[x]))
            v = int(res[x])
        test_seq = [random.randint(-nleaves, nleaves*nleaves) for x in range(0, nleaves)]

print("All is OK!")


All is OK!
