# Playground - data flow graph 

In [87]:
import os
import networkx as nx
import pyvis as pv
import matplotlib.pyplot as plt
from typing import Iterable, Tuple, Set, List, Dict, OrderedDict
import bril_model as bm
from bril_model.form_blocks import print_blocks, form_blocks

DEMO_BRIL_FILE = "../bril/benchmarks/core/is-decreasing.bril"
# DEMO_BRIL_FILE = "../bril/benchmarks/core/hanoi.bril"
# DEMO_BRIL_FILE = "../bril/benchmarks/core/birthday.bril"

bbs = bm.BrilScript(script_name=os.path.basename(DEMO_BRIL_FILE), file_dir=os.path.dirname(DEMO_BRIL_FILE))

In [88]:
ENTRY_POINT_NAME = 'ENTRY\nPOINT'
RETURN_POINT_NAME = 'RETURN\nPOINT'

def iter_func_blocks(bs: bm.BrilScript) -> Iterable[Tuple[bm.BrilFunction, OrderedDict[bm.BrilInstruction_Label, List[bm.BrilInstruction]]]]:
    for each_func in bs.functions:
        block_dict: OrderedDict[bm.BrilInstruction_Label, List[bm.BrilInstruction]] = {}
        anonymous_id = 0
        for each_block in form_blocks(each_func.instrs):
            this_block_label = None
            if isinstance(each_block[0], bm.BrilInstruction_Label):
                this_block_label = each_block[0]
                block_dict[this_block_label] = each_block[1:]
            else:
                this_block_label = bm.BrilInstruction_Label(dict(label='_f{}._anon{}'.format(str(hash(each_func.name)%1000).zfill(3), anonymous_id)))
                block_dict[this_block_label] = each_block[:]
                anonymous_id += 1
        yield (each_func, block_dict)

def update_to_graph(bbs: bm.BrilScript, app_graph_dict: Dict[bm.BrilFunction, Tuple[nx.DiGraph, Dict[bm.BrilInstruction_Label, List[bm.BrilInstruction]]]]):
    for each_func, block_dict in iter_func_blocks(bbs):
        print("Function: {}".format(each_func.name))
        fdg = nx.DiGraph()  # function directed graph
        last_label = None
        is_first = True
        for each_label, each_block in block_dict.items():
            fdg.add_node(each_label.label, instructions=[each_label]+each_block)
            if last_label:
                fdg.add_edge(last_label.label, each_label.label, reason='[fallthrough]')
            elif is_first:
                fdg.add_edge(ENTRY_POINT_NAME, each_label.label, reason='[enter]')
                is_first = False
            
            final_instr = each_block[-1]
            if final_instr.op in ['jmp', 'br']:
                for redirect_instr_to_dest in final_instr.labels:
                    fdg.add_edge(each_label.label, redirect_instr_to_dest, reason=final_instr.to_briltxt())
                last_label = None
            elif final_instr.op in ['ret']:
                fdg.add_edge(each_label.label, RETURN_POINT_NAME, reason=final_instr.to_briltxt())
                last_label = None
            else:
                last_label = each_label
        app_graph_dict[each_func] = (fdg, block_dict)

app_graph: Dict[bm.BrilFunction, Tuple[nx.DiGraph, Dict[bm.BrilInstruction_Label, List[bm.BrilInstruction]]]] = {}
update_to_graph(bbs, app_graph)

Function: main
Function: is_decreasing
Function: last_digit


In [89]:
# for each basic block, generate set of variables and expressions that:
# 1. are used before defined
# 2. are modified in the block
# Mark used before defined variables as set(GEN) and modified variables as set(KILL)
# GEN = {v | v is used before defined}
# KILL = {v | v is modified}

def args_use_before_assign_and_assign(instrs: List[bm.BrilInstruction]) -> Tuple[Set[str],Set[str]]:
    """
    Given a list of instructions, return the set of variables that are used before defined and the set of variables that are modified.
    This is from HW2, but modified to return the sets instead of printing them.
    """
    used_first: Set[str] = set()
    written: Set[str] = set()
    for instr in instrs:
        used_first.update(set(instr.args if instr.args else []) - written)
        if instr.dest:
            written.add(instr.dest)
    return used_first, written

def gen_kill_sets(block: List[bm.BrilInstruction]) -> Tuple[Set[str],Set[str]]:
    return args_use_before_assign_and_assign(block)

def update_gen_kill_sets(app_graph: Dict[bm.BrilFunction, Tuple[nx.DiGraph, Dict[bm.BrilInstruction_Label, List[bm.BrilInstruction]]]]):
    for _, (fdg, _) in app_graph.items():
        for each_node, each_node_data in fdg.nodes(data=True):
            each_block: List[bm.BrilInstruction] = each_node_data.get('instructions', None)
            if each_block:
                _gen, _kill = gen_kill_sets(each_block)
            else:
                _gen, _kill = set(), set()
            fdg.nodes[each_node]['GEN'] = _gen
            fdg.nodes[each_node]['KILL'] = _kill

update_gen_kill_sets(app_graph)

_get_node_gen_set = lambda fdg, node_name: fdg.nodes[node_name].get('GEN', set())
_get_node_kill_set = lambda fdg, node_name: fdg.nodes[node_name].get('KILL', set())
_get_node_in_set = lambda fdg, node_name: fdg.nodes[node_name].get('IN', set())
_get_node_out_set = lambda fdg, node_name: fdg.nodes[node_name].get('OUT', set())
_get_node_pred_set = lambda fdg, node_name: set(fdg.predecessors(node_name))
_get_node_succ_set = lambda fdg, node_name: set(fdg.successors(node_name))

def _fdg_update_liveness_sets(fdg: nx.DiGraph) -> bool:
    has_changed = False
    for this_node, _ in fdg.nodes(data=True):
        if this_node == RETURN_POINT_NAME:
            _out = set()
        else:
            _temp_succ_req = [_get_node_gen_set(fdg, each_succ_node) for each_succ_node in _get_node_succ_set(fdg, this_node)]
            if _temp_succ_req:
                _out = set.union(*_temp_succ_req)
            else:
                _out = set()
        _in = set.union(_get_node_gen_set(fdg, this_node), set.difference(_out, _get_node_kill_set(fdg, this_node)))
        if _in != _get_node_in_set(fdg, this_node) or _out != _get_node_out_set(fdg, this_node):
            fdg.nodes[this_node]['IN'] = _in
            fdg.nodes[this_node]['OUT'] = _out
            has_changed = True
    return has_changed

def update_liveness_sets(app_graph: Dict[bm.BrilFunction, Tuple[nx.DiGraph, Dict[bm.BrilInstruction_Label, List[bm.BrilInstruction]]]]):
    has_changed = True
    while has_changed:
        has_changed = False
        print("Updating liveness sets")
        for _, (fdg, _) in app_graph.items():
            has_changed |= _fdg_update_liveness_sets(fdg)

update_liveness_sets(app_graph)

Updating liveness sets
Updating liveness sets


In [90]:
app_graph

{BrilFunction ::	main ( x<int> ) -> None: <3 instr>: (<networkx.classes.digraph.DiGraph at 0x7f0618c1cf40>,
  {BrilInstruction_Label ::	<|._f142._anon0|>: [BrilInstruction_ValOp ::	tmp0 (bool) <- [call] is_decreasing ,
    BrilInstruction_ValOp ::	tmp (bool) <- [id] tmp0 ,
    BrilInstruction_EffOp ::	[print] tmp ]}),
 BrilFunction ::	is_decreasing ( x<int> ) -> bool: <29 instr>: (<networkx.classes.digraph.DiGraph at 0x7f0618c1d660>,
  {BrilInstruction_Label ::	<|._f644._anon0|>: [BrilInstruction_ValOp ::	tmp (int) <- [id] x ,
    BrilInstruction_Const ::	tmp1 (int) <- [const] 1,
    BrilInstruction_Const ::	tmp2 (int) <- [const] -1,
    BrilInstruction_ValOp ::	tmp3 (int) <- [mul] tmp1, tmp2 ,
    BrilInstruction_ValOp ::	prev (int) <- [id] tmp3 ],
   BrilInstruction_Label ::	<|.label4|>: [BrilInstruction_Const ::	tmp7 (int) <- [const] 0,
    BrilInstruction_ValOp ::	tmp8 (bool) <- [gt] tmp, tmp7 ,
    BrilInstruction_EffOp ::	[br] tmp8 {label5, label6} ],
   BrilInstruction_Label ::	

In [108]:
fdg = app_graph.get(bbs.functions[1])[0]

# Plot with pyvis
net = pv.network.Network(
    directed=True,
    neighborhood_highlight=True,
    notebook=True,
    cdn_resources="remote", 
    height="100vh",
)

# net.show_buttons(['nodes', 'edges', 'layout', 'interaction', 'manipulation', 'physics', 'selection', 'renderer']) # Show part 3 in the plot (optional)
net.from_nx(fdg) # Create directly from nx graph
_net_node_name2idx: Dict[str, int] = {node['id']: idx for idx, node in enumerate(net.nodes)}
_map_node_name2idx = lambda node_name: _net_node_name2idx.get(node_name, None)

# net.show_buttons(['physics',])

net.options.interaction.hover = True
net.options.physics.use_force_atlas_2based(dict(
    gravity=-100, 
    spring_strength=0.05,
    central_gravity=0.01,
    spring_length=200,
    damping=0.45,
    overlap=0,
))

def generate_set_str(showing_set: set) -> str:
    return '{' + ','.join(showing_set) + '}'

# Traverse the net nodes of PyVis and convert the data
for node in net.nodes:
    _in = generate_set_str(node.pop('IN', set()))
    _out = generate_set_str(node.pop('OUT', set()))
    _gen = generate_set_str(node.pop('GEN', set()))
    _kill = generate_set_str(node.pop('KILL', set()))

    node['title'] = ''
    if 'instructions' in node:
        # remove 'data' key from node, and set 'title' key to the string representation of the data
        node['title'] = "\n  ".join([obj.to_briltxt() if hasattr(obj, 'to_briltxt') else str(obj) for obj in node.pop('instructions', [])])
        node['shape'] = 'box'
    node['title'] += "\n\n" + f"GEN: { _gen }" + "\n" + f"KILL: { _kill}"
    node['title'] = node['title'].strip(' \t\n\r')
    # title word layout change
    # find and replace double return with dash line
    # find and replace double space with full corner single space
    node['title'] = node['title'].replace('\n\n', '\n--------\n').replace('  ', '\u3000')

    _liveness_labels = f"IN: {_in}", f"OUT: {_out}"
    node['label'] = _liveness_labels[0] + '\n' + node['label'] + '\n' + _liveness_labels[1]

    if node['id'] in (ENTRY_POINT_NAME, RETURN_POINT_NAME):
        node['color'] = 'grey'
        node['shape'] = 'circle'

for edge in net.edges:
    _reason = edge.pop('reason', None)
    if _reason:
        edge['label'] = _reason
    # _src_node = net.nodes[_map_node_name2idx( edge['from'] )]
    # _dst_node = net.nodes[_map_node_name2idx( edge['to'] )]
    # if _src_node.get('out', set()) & _dst_node.get('in', set()):
    #     _liveness_src_out_label = ', '.join(_src_node.get('out', set()) & _dst_node.get('in', set()))

net.save_graph('test.html')

In [83]:
print_blocks(bbs)

BrilFunction ::	main ( x<int> ) -> None: <3 instr>
anonymous block:
  BrilInstruction_ValOp ::	tmp0 (bool) <- [call] is_decreasing 
  BrilInstruction_ValOp ::	tmp (bool) <- [id] tmp0 
  BrilInstruction_EffOp ::	[print] tmp 

BrilFunction ::	is_decreasing ( x<int> ) -> bool: <29 instr>
anonymous block:
  BrilInstruction_ValOp ::	tmp (int) <- [id] x 
  BrilInstruction_Const ::	tmp1 (int) <- [const] 1
  BrilInstruction_Const ::	tmp2 (int) <- [const] -1
  BrilInstruction_ValOp ::	tmp3 (int) <- [mul] tmp1, tmp2 
  BrilInstruction_ValOp ::	prev (int) <- [id] tmp3 
block "label4":
  BrilInstruction_Const ::	tmp7 (int) <- [const] 0
  BrilInstruction_ValOp ::	tmp8 (bool) <- [gt] tmp, tmp7 
  BrilInstruction_EffOp ::	[br] tmp8 {label5, label6} 
block "label5":
  BrilInstruction_ValOp ::	tmp9 (int) <- [call] last_digit 
  BrilInstruction_ValOp ::	digit (int) <- [id] tmp9 
  BrilInstruction_ValOp ::	tmp10 (bool) <- [lt] digit, prev 
  BrilInstruction_EffOp ::	[br] tmp10 {label11, label12} 
block "