# Playground - data flow graph 

In [21]:
import os
import networkx as nx
import pyvis as pv
import matplotlib.pyplot as plt
from typing import Iterable, Tuple, Set, List, Dict, OrderedDict, Optional
import bril_model as bm
from bril_model.form_blocks import print_blocks, form_blocks

# Directory the rendered graphviz files are written into (see the render cell).
SAVE_DIR = "./save"

# Bril program analysed by this notebook. Alternatives are kept commented out;
# switch programs by (un)commenting the DEMO_BRIL_FILE lines.
# DEMO_BRIL_FILE = "./example/in_class_example_1.bril"
# DEMO_BRIL_FILE = "./example/in_class_example_2.bril"
DEMO_BRIL_FILE = "./example/in_class_example_3.bril"
# DEMO_BRIL_FILE = "../bril/benchmarks/core/is-decreasing.bril"
# DEMO_BRIL_FILE = "../bril/benchmarks/core/hanoi.bril"
# DEMO_BRIL_FILE = "../bril/benchmarks/core/birthday.bril"
# DEMO_BRIL_FILE = "../bril/examples/test/df/fact.bril"
# DEMO_BRIL_FILE = "../bril/examples/test/df/cond.bril"
# DEMO_BRIL_FILE = "../bril/examples/test/df/cond-args.bril"
# DEMO_BRIL_FILE = "../bril/examples/test/dom/loopcond.bril"
# DEMO_BRIL_FILE = "../bril/examples/test/dom/while.bril"

In [22]:
# Label names of the synthetic entry/exit blocks that update_to_graph adds to
# every function's CFG. Blocks are looked up by label name, so a real Bril
# label with one of these names would presumably be ambiguous — avoid it.
ENTRY_POINT_NAME = 'ENTRY'
RETURN_POINT_NAME = 'RETURN'

In [23]:
class BasicBlock():
    """A basic block of a Bril function, identified by its leading label.

    ``instrs[0]`` is always the block's label instruction and the rest of
    ``instrs`` is the block body. ``succ``/``pred`` hold the outgoing and
    incoming CFG edges; they start empty and are filled in by
    ``update_to_graph``.

    Equality and hashing use only the label. Ordering puts the ENTRY block
    first and the RETURN block last; all other blocks compare by label.
    """

    def __init__(self, label: bm.BrilInstruction_Label, instrs: List[bm.BrilInstruction]):
        """Create a block named ``label`` holding ``instrs``.

        ``instrs`` is copied, and ``label`` is prepended unless it is already
        the first instruction, so the caller's list is never modified.

        Raises:
            TypeError: if ``label`` is not a ``bm.BrilInstruction_Label``.
        """
        if not isinstance(label, bm.BrilInstruction_Label):
            raise TypeError(f"Expected label, got {type(label)}")
        # Copy so that prepending the label never mutates the caller's list
        # (e.g. the block list yielded by form_blocks).
        instrs = list(instrs)
        # The `label` property reads instrs[0], so the label must be at the
        # front, not merely somewhere in the list.
        if not instrs or instrs[0] != label:
            instrs.insert(0, label)
        self.instrs = instrs
        self.succ: Set[BasicBlock] = set()
        self.pred: Set[BasicBlock] = set()

    @property
    def label(self):
        """The label instruction naming this block (always ``instrs[0]``)."""
        return self.instrs[0]

    def __hash__(self) -> int:
        return hash(self.label)

    def __eq__(self, o: object) -> bool:
        if not isinstance(o, BasicBlock):
            return False
        return self.label == o.label

    def __lt__(self, o: object) -> bool:
        """Order blocks as ENTRY < (others, by label) < RETURN."""
        if not isinstance(o, BasicBlock):
            # Let Python try the reflected operation / raise TypeError
            # instead of silently claiming "not less than".
            return NotImplemented
        if self == o:
            # Strict ordering must be irreflexive (ENTRY < ENTRY is False).
            return False
        # return block is always the last one
        if self.label.label == RETURN_POINT_NAME: return False
        if o.label.label == RETURN_POINT_NAME: return True
        # entry block is always the first one
        if self.label.label == ENTRY_POINT_NAME: return True
        if o.label.label == ENTRY_POINT_NAME: return False
        return self.label < o.label

    def __str__(self):
        """Bril-text rendering: the label followed by indented instructions."""
        if self.label.label in {ENTRY_POINT_NAME, RETURN_POINT_NAME}: return self.label.label
        _instrs = '\n  '.join([x.to_briltxt() for x in self.instrs[1:]])
        if _instrs:
            _instrs = f"\n  {_instrs}"
        return f"{self.label.to_briltxt()}{_instrs}"

    def __repr__(self):
        return f"{self.__class__.__name__} :: {self.label.to_briltxt()}..[{len(self.instrs) - 1}]"

In [24]:
def iter_func_blocks(bs: bm.BrilScript) -> Iterable[Tuple[bm.BrilFunction, List[BasicBlock]]]:
    """Yield ``(function, basic_blocks)`` for every function in ``bs``.

    Blocks come from ``form_blocks``. A block that does not start with a label
    (e.g. a function's first block) gets a synthetic label
    ``_fNNN._anonK``, where ``NNN`` is a 3-digit tag derived from the function
    name and ``K`` counts anonymous blocks within that function.
    """
    # Local import: stable checksum for the synthetic labels.
    import zlib

    for each_func in bs.functions:
        # Built-in hash() of a str is salted per process (PYTHONHASHSEED), so
        # labels derived from it changed between runs. CRC32 is deterministic.
        func_tag = str(zlib.crc32(str(each_func.name).encode("utf-8")) % 1000).zfill(3)
        bbs: List[BasicBlock] = []
        anonymous_id = 0
        for each_block in form_blocks(each_func.instrs):
            if each_block and isinstance(each_block[0], bm.BrilInstruction_Label):
                this_block_label = each_block[0]
            else:
                this_block_label = bm.BrilInstruction_Label(
                    dict(label='_f{}._anon{}'.format(func_tag, anonymous_id))
                )
                anonymous_id += 1
            bbs.append(BasicBlock(this_block_label, each_block))
        yield (each_func, bbs)

def update_to_graph(bscript: bm.BrilScript) -> Dict[bm.BrilFunction, List[BasicBlock]]:
    """Build the control-flow graph of every function in ``bscript``.

    Blocks are linked through their ``succ``/``pred`` sets. Two synthetic
    blocks are added per function:

    - ``ENTRY`` is prepended, with an edge to the first real block.
    - ``RETURN`` is appended. It is reached from every block ending in
      ``ret`` and from a last block that falls off the end. In a function
      with no blocks it is reached directly from ENTRY.

    Prints ``Function: <name>`` for each function processed.

    Returns:
        dict mapping each function to ``[ENTRY, <blocks in order>, RETURN]``.

    Raises:
        ValueError: if a ``jmp``/``br`` targets a label that has no block.
    """

    def _add_edge(src: BasicBlock, dst: BasicBlock) -> None:
        """Record the CFG edge ``src -> dst`` on both endpoints."""
        src.succ.add(dst)
        dst.pred.add(src)

    app_bbs_dict: Dict[bm.BrilFunction, List[BasicBlock]] = {}
    for each_func, basic_blocks in iter_func_blocks(bscript):
        print("Function: {}".format(each_func.name))
        entry_bb = BasicBlock(bm.BrilInstruction_Label(dict(label=ENTRY_POINT_NAME)), [])
        return_bb = BasicBlock(bm.BrilInstruction_Label(dict(label=RETURN_POINT_NAME)), [])

        # Jump-target lookup built once per function instead of scanning all
        # blocks for every target. setdefault keeps the first block for a
        # duplicated label, matching a front-to-back search.
        bb_by_label: Dict[str, BasicBlock] = {}
        for bb in basic_blocks:
            bb_by_label.setdefault(bb.label.label, bb)

        prev_bb: Optional[BasicBlock] = None  # block that falls through into the next one
        is_first = True
        for each_bb in basic_blocks:
            if prev_bb is not None:
                _add_edge(prev_bb, each_bb)  # reason: fallthrough edge
            elif is_first:
                _add_edge(entry_bb, each_bb)  # reason: entry edge
            is_first = False

            # get last instruction if exists
            final_instr = each_bb.instrs[-1] if each_bb.instrs else None
            if final_instr is None:
                # empty block, may fall through
                prev_bb = each_bb
            elif final_instr.op in ('jmp', 'br'):
                # explicit control transfer: edge to every target, no fallthrough
                for target_label in final_instr.labels:
                    next_bb = bb_by_label.get(target_label)
                    if next_bb is None:
                        raise ValueError(f"Cannot find block with label {target_label}")
                    _add_edge(each_bb, next_bb)
                prev_bb = None
            elif final_instr.op == 'ret':
                # explicit return block, no fallthrough
                _add_edge(each_bb, return_bb)
                prev_bb = None
            else:
                # normal instruction, prepare for fallthrough
                prev_bb = each_bb

        if prev_bb is not None:
            # last block falls off the end of the function: implicit return
            _add_edge(prev_bb, return_bb)
        elif is_first:
            # Function without any block: ENTRY flows straight into RETURN.
            # Otherwise RETURN would have no predecessor at all, and the
            # dominator computation would intersect an empty list of sets.
            _add_edge(entry_bb, return_bb)

        # add entry/return block
        basic_blocks.insert(0, entry_bb)
        basic_blocks.append(return_bb)
        app_bbs_dict[each_func] = basic_blocks
    return app_bbs_dict

def getBasicBlockByLabel(bb_list: List[BasicBlock], label: str) -> Optional[BasicBlock]:
    """Return the first block in ``bb_list`` whose label name is ``label``.

    Returns ``None`` when no block carries that label.
    """
    for candidate in bb_list:
        if candidate.label.label == label:
            return candidate
    return None

# Load the Bril script selected by DEMO_BRIL_FILE (bm.BrilScript takes the file
# name and its directory separately), then build every function's CFG as
# {function: [ENTRY, <basic blocks>, RETURN]}.
bscript = bm.BrilScript(script_name=os.path.basename(DEMO_BRIL_FILE), file_dir=os.path.dirname(DEMO_BRIL_FILE))
app_graph: Dict[bm.BrilFunction, List[BasicBlock]] = update_to_graph(bscript)

Function: in_class_example_3
Function: main


In [25]:
# Print the parsed functions with their indices and pick the first one as
# test_func. Indexing [0] raises IndexError if the script has no functions.
app_func_list = list(app_graph.keys())
print("Functions: {}".format(', '.join([f"[{idx}]={x.name}" for idx, x in enumerate(app_func_list)])))
test_func = app_func_list[0]

Functions: [0]=in_class_example_3, [1]=main


In [26]:
def generate_dom_dict(bbs: List[BasicBlock]) -> dict[BasicBlock, set[BasicBlock]]:
    """Compute the dominator set of every block in ``bbs``.

    Uses the iterative data-flow formulation

        dom(ENTRY) = {ENTRY}
        dom(b)     = {b} | intersection(dom(p) for p in pred(b))

    and repeats passes over ``bbs`` until no set changes.

    A non-entry block with no predecessors (unreachable code, or RETURN when
    nothing reaches it) keeps its initial "all blocks" set. That is the
    empty-intersection convention. It also leaves such blocks neutral when
    they appear as a predecessor of a reachable block.

    Raises:
        ValueError: if ``bbs`` contains no ENTRY block.
    """
    # The entry block only dominates itself
    bb_entry = getBasicBlockByLabel(bbs, ENTRY_POINT_NAME)
    if bb_entry is None:
        raise ValueError(f"No '{ENTRY_POINT_NAME}' block in the given block list")

    # Initially, every block is dominated by every block
    dom: Dict[BasicBlock, Set[BasicBlock]] = {bb: set(bbs) for bb in bbs}
    dom[bb_entry] = {bb_entry}

    changed = True
    while changed:
        changed = False
        for each_bb in bbs:
            if each_bb == bb_entry or not each_bb.pred:
                # set.intersection() with no sets would raise TypeError.
                continue
            # Intersection of the predecessors' dominator sets (a fresh set,
            # so adding to it never mutates a predecessor's entry) ...
            this_bb_dom = set.intersection(*(dom[each_pred] for each_pred in each_bb.pred))
            # ... plus the block itself.
            this_bb_dom.add(each_bb)
            if dom[each_bb] != this_bb_dom:
                dom[each_bb] = this_bb_dom
                changed = True
    return dom

def lambda_get_lease_superset_dom(dom_dict: "dict[BasicBlock, set[BasicBlock]]", bb: "BasicBlock") -> "Optional[BasicBlock]":
    """Return the immediate dominator of ``bb``, or ``None`` if it has none.

    ``dom_dict`` maps each block to its dominator set (as built by
    ``generate_dom_dict``). A strict dominator ``d`` of ``bb`` is the immediate
    one when ``dom_dict[bb] - dom_dict[d] == {bb}``, i.e. ``d`` is dominated by
    every other strict dominator of ``bb``. The entry block has no immediate
    dominator, so ``None`` is returned for it.

    Kept as a named function (formerly a lambda bound to this name) so existing
    callers of ``lambda_get_lease_superset_dom`` keep working.
    """
    own_doms = dom_dict[bb]
    only_self = {bb}
    for test_dom in own_doms:
        if own_doms - dom_dict[test_dom] == only_self:
            return test_dom
    return None

def get_lease_superset_dom(dom_dict: dict[BasicBlock, set[BasicBlock]], bb: BasicBlock) -> Optional[BasicBlock]:
    """Return the immediate dominator of ``bb``; ``None`` if there is none.

    A candidate from ``bb``'s own dominator set qualifies when subtracting
    its dominator set from ``bb``'s leaves exactly ``{bb}``.
    """
    candidates = dom_dict[bb]
    only_self = set([bb])
    matches = (cand for cand in candidates if (candidates - dom_dict[cand]) == only_self)
    return next(matches, None)


In [27]:
# Print the script summary, then for every function each block's full
# dominator set (sorted so ENTRY comes first and RETURN last).
print(f"{bscript}")
print()
for each_func, bbs in app_graph.items():
    print(each_func)
    dom = generate_dom_dict(bbs)
    # column width for aligning block names
    _max_bb_label_len = max([len(x.label.label) for x in bbs])
    for each_bb, dom_set in dom.items():
        print(f"  {each_bb.label.label.ljust(_max_bb_label_len)}: {', '.join([x.label.label for x in sorted(dom_set)])}")
    print()

BrilScript ::	in_class_example_3.bril <2 func> @ ./example

BrilFunction ::	in_class_example_3 ( input<int> ) -> None: <23 instr>
  ENTRY       : ENTRY
  _f795._anon0: ENTRY, _f795._anon0
  bb1         : ENTRY, _f795._anon0, bb1
  bb2         : ENTRY, _f795._anon0, bb1, bb2
  bb_compare  : ENTRY, _f795._anon0, bb1, bb2, bb_compare
  bb_loop     : ENTRY, _f795._anon0, bb1, bb2, bb_compare, bb_loop
  bb_loop_1   : ENTRY, _f795._anon0, bb1, bb2, bb_compare, bb_loop, bb_loop_1
  bb_loop_2   : ENTRY, _f795._anon0, bb1, bb2, bb_compare, bb_loop, bb_loop_1, bb_loop_2
  bb_return   : ENTRY, _f795._anon0, bb1, bb2, bb_compare, bb_return
  RETURN      : ENTRY, _f795._anon0, bb1, bb2, bb_compare, bb_return, RETURN

BrilFunction ::	main (  ) -> None: <2 instr>
  ENTRY       : ENTRY
  _f159._anon0: ENTRY, _f159._anon0
  RETURN      : ENTRY, _f159._anon0, RETURN



In [28]:
import graphviz


def get_style_attrs(bb: BasicBlock) -> dict:
    """Return the graphviz node attributes used to draw ``bb``.

    The synthetic ENTRY/RETURN blocks become filled light-gray double circles;
    every other block is a plain box. A fresh dict is returned on each call.
    """
    is_synthetic = bb.label.label in [ENTRY_POINT_NAME, RETURN_POINT_NAME]
    if not is_synthetic:
        return {'shape': 'box'}
    return {
        'shape': 'doublecircle',
        'style': 'filled',
        'fillcolor': 'lightgray',
    }

def create_graphviz_dot_node(dot: graphviz.Digraph, func: bm.BrilFunction, bb: BasicBlock) -> None:
    """Add ``bb`` to ``dot`` as a node labelled with the block's Bril text.

    The node id ``FUNC_<func>_NODE_<label>`` is unique across functions.
    """
    node_id = f"FUNC_{func.name}_NODE_{bb.label.label}"
    node_label = str(bb)
    dot.node(node_id, label=node_label, **get_style_attrs(bb))

def create_graphviz_dot_edge(dot: graphviz.Digraph, func: bm.BrilFunction, from_bb: BasicBlock, to_bb: BasicBlock) -> None:
    """Add the directed edge ``from_bb -> to_bb`` to ``dot``.

    Endpoints use the same ``FUNC_<func>_NODE_<label>`` ids as
    ``create_graphviz_dot_node``.
    """
    tail_id, head_id = (f"FUNC_{func.name}_NODE_{node.label.label}" for node in (from_bb, to_bb))
    dot.edge(tail_id, head_id)

def _style_func_cluster(cluster: graphviz.Digraph, func_name: str) -> None:
    """Apply the shared per-function cluster styling (title and light fill)."""
    cluster.attr(label=func_name)
    cluster.attr(color='#f7f7f7')
    cluster.attr(style='filled')
    # NOTE(review): Graphviz documents rankdir as a root-graph attribute; on a
    # cluster it likely has no effect. Kept to preserve the emitted DOT.
    cluster.attr(rankdir='TB')


def create_graphviz_dot(app_graph: Dict[bm.BrilFunction, List[BasicBlock]]) -> Tuple[graphviz.Digraph, graphviz.Digraph]:
    """Build graphviz digraphs of every function's CFG and dominator tree.

    Each function becomes a ``cluster_<name>`` subgraph in both graphs.
    - CFG: one edge per ``succ`` entry.
    - Dominator tree: an edge from each block's immediate dominator, as found
      by ``get_lease_superset_dom``. Blocks without one are roots.

    Also prints, per function, every block with its immediate dominator
    (``[ROOT]`` if none) and its full dominator set.

    Returns:
        ``(dot_cfg, dot_domt)``: the control-flow graph and the dominator tree.
    """
    dot_cfg = graphviz.Digraph('CFG', comment='Control Flow Graph')
    dot_domt = graphviz.Digraph('DOMTREE', comment='Dominator Tree')
    for each_func, bbs in app_graph.items():
        print(each_func)
        dom = generate_dom_dict(bbs)
        # column width for aligning block names in the printout
        _max_bb_label_len = max(len(x.label.label) for x in bbs)
        cluster_name = f"cluster_{each_func.name}"

        # each func is a subgraph ("cluster_" prefix makes graphviz draw a box)
        with dot_cfg.subgraph(name=cluster_name) as dot_cfg_func, \
                dot_domt.subgraph(name=cluster_name) as dot_domt_func:
            _style_func_cluster(dot_cfg_func, each_func.name)
            _style_func_cluster(dot_domt_func, each_func.name)

            # CFG. Successors are sorted so the emitted DOT (and therefore the
            # layout) does not depend on set iteration order.
            for each_bb in bbs:
                create_graphviz_dot_node(dot_cfg_func, each_func, each_bb)
                for each_succ in sorted(each_bb.succ):
                    create_graphviz_dot_edge(dot_cfg_func, each_func, each_bb, each_succ)

            # Dominator tree: edge from the immediate dominator to the block.
            for each_bb in dom:
                least_superset_dom = get_lease_superset_dom(dom, each_bb)
                idom_name = least_superset_dom.label.label if least_superset_dom else '[ROOT]'
                dom_names = ', '.join([x.label.label for x in sorted(dom[each_bb])])
                print(f"  {each_bb.label.label.ljust(_max_bb_label_len)} <- {idom_name.ljust(_max_bb_label_len)} : {dom_names}")
                create_graphviz_dot_node(dot_domt_func, each_func, each_bb)
                if least_superset_dom:
                    create_graphviz_dot_edge(dot_domt_func, each_func, least_superset_dom, each_bb)

            print()
    return dot_cfg, dot_domt

# Build both graphs and render them under SAVE_DIR. render() writes the DOT
# source to the given path and the rendered output next to it with the format
# suffix appended (here a .pdf, as the cell output shows).
dot_cfg, dot_domt = create_graphviz_dot(app_graph)
dot_cfg.render(os.path.join(SAVE_DIR, bscript.script_name+".cfg.gv"), view=False)
dot_domt.render(os.path.join(SAVE_DIR, bscript.script_name+".domt.gv"), view=False)


BrilFunction ::	in_class_example_3 ( input<int> ) -> None: <23 instr>
  ENTRY        <- [ROOT]       : ENTRY
  _f795._anon0 <- ENTRY        : ENTRY, _f795._anon0
  bb1          <- _f795._anon0 : ENTRY, _f795._anon0, bb1
  bb2          <- bb1          : ENTRY, _f795._anon0, bb1, bb2
  bb_compare   <- bb2          : ENTRY, _f795._anon0, bb1, bb2, bb_compare
  bb_loop      <- bb_compare   : ENTRY, _f795._anon0, bb1, bb2, bb_compare, bb_loop
  bb_loop_1    <- bb_loop      : ENTRY, _f795._anon0, bb1, bb2, bb_compare, bb_loop, bb_loop_1
  bb_loop_2    <- bb_loop_1    : ENTRY, _f795._anon0, bb1, bb2, bb_compare, bb_loop, bb_loop_1, bb_loop_2
  bb_return    <- bb_compare   : ENTRY, _f795._anon0, bb1, bb2, bb_compare, bb_return
  RETURN       <- bb_return    : ENTRY, _f795._anon0, bb1, bb2, bb_compare, bb_return, RETURN

BrilFunction ::	main (  ) -> None: <2 instr>
  ENTRY        <- [ROOT]       : ENTRY
  _f159._anon0 <- ENTRY        : ENTRY, _f159._anon0
  RETURN       <- _f159._anon0 : ENTRY,

'save/in_class_example_3.bril.domt.gv.pdf'

In [29]:
# Debug cell: cross-check the immediate-dominator lookup by listing every
# candidate. Output markers per block:
#   "~~~~~"  result of lambda_get_lease_superset_dom ("None ?????" if none)
#   "!!!!!"  a strict dominator that passes the immediate-dominator test
#   "+ ..."  any other strict dominator, with dom[block] - dom[candidate]
for each_func, bbs in app_graph.items():
    print(each_func)
    dom = generate_dom_dict(bbs)
    _max_bb_label_len = max([len(x.label.label) for x in bbs])
    for each_bb in dom.keys():
        least_superset_dom = lambda_get_lease_superset_dom(dom, each_bb)
        print(f"  {each_bb.label.label.ljust(_max_bb_label_len)}: {(least_superset_dom.label.label if least_superset_dom else 'None ?????').ljust(_max_bb_label_len)} ~~~~~")

        for test_dom in dom[each_bb]:
            # a block trivially dominates itself; only strict dominators are candidates
            if test_dom == each_bb: continue
            diff =  dom[each_bb] -  dom[test_dom]
            if diff == set([each_bb]): 
                print(f"  {each_bb.label.label.ljust(_max_bb_label_len)}: {test_dom.label.label.ljust(_max_bb_label_len)} !!!!!")
            else:
                print(f"  {each_bb.label.label.ljust(_max_bb_label_len)}: {test_dom.label.label.ljust(_max_bb_label_len)} + {', '.join([x.label.label for x in sorted(diff)])}")
        print()
    print()

BrilFunction ::	in_class_example_3 ( input<int> ) -> None: <23 instr>
  ENTRY       : None ?????   ~~~~~

  _f795._anon0: ENTRY        ~~~~~
  _f795._anon0: ENTRY        !!!!!

  bb1         : _f795._anon0 ~~~~~
  bb1         : _f795._anon0 !!!!!
  bb1         : ENTRY        + _f795._anon0, bb1

  bb2         : bb1          ~~~~~
  bb2         : _f795._anon0 + bb1, bb2
  bb2         : ENTRY        + _f795._anon0, bb1, bb2
  bb2         : bb1          !!!!!

  bb_compare  : bb2          ~~~~~
  bb_compare  : _f795._anon0 + bb1, bb2, bb_compare
  bb_compare  : ENTRY        + _f795._anon0, bb1, bb2, bb_compare
  bb_compare  : bb2          !!!!!
  bb_compare  : bb1          + bb2, bb_compare

  bb_loop     : bb_compare   ~~~~~
  bb_loop     : _f795._anon0 + bb1, bb2, bb_compare, bb_loop
  bb_loop     : bb_compare   !!!!!
  bb_loop     : bb2          + bb_compare, bb_loop
  bb_loop     : ENTRY        + _f795._anon0, bb1, bb2, bb_compare, bb_loop
  bb_loop     : bb1          + bb2, bb_compar

In [30]:
# Dump the script's blocks as produced by bril_model.form_blocks, for
# comparison with the BasicBlock graph built above.
print_blocks(bscript)

BrilFunction ::	in_class_example_3 ( input<int> ) -> None: <23 instr>
anonymous block:
  BrilInstruction_ValOp ::	x (int) <- [id] input 
  BrilInstruction_Const ::	one (int) <- [const] 1
  BrilInstruction_Const ::	two (int) <- [const] 2
block "bb1":
  BrilInstruction_ValOp ::	a (int) <- [sub] x, one 
block "bb2":
  BrilInstruction_ValOp ::	b (int) <- [sub] x, two 
block "bb_compare":
  BrilInstruction_Const ::	zero (int) <- [const] 0
  BrilInstruction_ValOp ::	cond (bool) <- [gt] x, zero 
  BrilInstruction_EffOp ::	[br] cond {bb_loop, bb_return} 
block "bb_loop":
block "bb_loop_1":
  BrilInstruction_ValOp ::	totp1 (int) <- [mul] a, b 
  BrilInstruction_ValOp ::	totp2 (int) <- [sub] totp1, x 
  BrilInstruction_EffOp ::	[print] totp2 
block "bb_loop_2":
  BrilInstruction_ValOp ::	x (int) <- [sub] x, one 
  BrilInstruction_EffOp ::	[jmp]  {bb_compare} 
block "bb_return":
  BrilInstruction_ValOp ::	totp1 (int) <- [mul] a, b 
  BrilInstruction_EffOp ::	[print] totp1 
  BrilInstruction_EffOp