# Playground - data flow graph 

In [58]:
import os
import networkx as nx
import pyvis as pv
import matplotlib.pyplot as plt
from typing import Iterable, Tuple, Set, List, Dict, OrderedDict
import bril_model as bm
from bril_model.form_blocks import print_blocks, form_blocks

# Candidate Bril programs to analyse. Every uncommented assignment below rebinds
# the same name, so only the LAST one takes effect; the earlier ones are kept
# around purely to make switching inputs a matter of reordering lines.
# DEMO_BRIL_FILE = "../bril/benchmarks/core/is-decreasing.bril"
# DEMO_BRIL_FILE = "../bril/benchmarks/core/hanoi.bril"
DEMO_BRIL_FILE = "../bril/benchmarks/core/birthday.bril"
DEMO_BRIL_FILE = "../bril/examples/test/df/fact.bril"
DEMO_BRIL_FILE = "../bril/examples/test/df/cond.bril"
DEMO_BRIL_FILE = "../bril/examples/test/df/cond-args.bril"

# Load the selected program; the path is split into its file name and directory
# because that is how BrilScript takes it.
bbs = bm.BrilScript(script_name=os.path.basename(DEMO_BRIL_FILE), file_dir=os.path.dirname(DEMO_BRIL_FILE))

In [59]:
# Selects which dataflow analysis the fixed-point cell further down runs.
# NOTE: the spelling "availabililty" (sic) is the exact string that cell compares
# against, so keep it as-is when switching analyses.
ANALYSIS = "liveness"
# ANALYSIS = "availabililty"

In [60]:

class Expr:
    """
    An Expr uniquely represents a computation in terms of sub-values.
    This is from HW2, but change `argnums` back to `args`.

    Two Exprs are equal (and hash alike) when they have the same operator and
    the same operand list. For commutative operators the operands are stored
    in sorted order, so ``add(a, b)`` and ``add(b, a)`` are the same Expr.
    """

    # Operators whose result does not depend on operand order.
    COMMUTATIVE_OPS = ('add', 'mul', 'and', 'or', 'eq')

    def __init__(self, op: str, args: List[str]):
        """
        :param op: operator name, e.g. ``'add'`` or ``'sub'``.
        :param args: operand names, in source order.
        """
        self.op = op
        if op in self.COMMUTATIVE_OPS:
            # Canonicalise operand order so commutative duplicates collapse.
            self.args = sorted(args)
        else:
            # Copy so that later mutation of the caller's list cannot silently
            # change this object's hash while it sits inside a set.
            self.args = list(args)

    def __hash__(self):
        return hash((self.op, tuple(self.args)))

    def __eq__(self, other):
        if not isinstance(other, Expr):
            # Let Python fall back to its default comparison instead of
            # raising AttributeError on unrelated objects.
            return NotImplemented
        return self.op == other.op and self.args == other.args

    def __repr__(self):
        return f"{__class__.__name__} {str(self)}"

    def __str__(self):
        return f"{self.op}({','.join(self.args)})"

In [61]:
# Node ids of the two synthetic graph nodes that stand for "before the first
# block" and "after a ret". They are plain strings so they can sit next to real
# block labels in the same graph; the embedded newline is escaped when the
# names are printed by the analysis passes.
ENTRY_POINT_NAME = 'ENTRY\nPOINT'
RETURN_POINT_NAME = 'RETURN\nPOINT'

def iter_func_blocks(bs: bm.BrilScript) -> Iterable[Tuple[bm.BrilFunction, OrderedDict[bm.BrilInstruction_Label, List[bm.BrilInstruction]]]]:
    """
    Yield ``(function, blocks)`` for every function of ``bs``.

    ``blocks`` maps a label instruction to the instructions of that basic
    block, in the order ``form_blocks`` produced them. A block that starts
    with a label is keyed by that label and stored without it. Any other
    block is keyed by a freshly built label named ``_f<NNN>._anon<k>``, where
    ``NNN`` is ``hash(function name) % 1000`` zero-padded to three digits and
    ``k`` counts the anonymous blocks of the function from 0.

    NOTE(review): if the function name is a ``str``, ``hash()`` is salted per
    interpreter run, so the ``NNN`` part is presumably not stable across runs.
    """
    for func in bs.functions:
        blocks_by_label = {}
        anon_counter = 0
        for block in form_blocks(func.instrs):
            head = block[0]
            if isinstance(head, bm.BrilInstruction_Label):
                # Labelled block: the label is the key, the rest is the body.
                blocks_by_label[head] = block[1:]
                continue
            # Unlabelled block: invent a label so every block has a key.
            func_tag = str(hash(func.name) % 1000).zfill(3)
            synthetic_label = bm.BrilInstruction_Label(dict(label='_f{}._anon{}'.format(func_tag, anon_counter)))
            blocks_by_label[synthetic_label] = block[:]
            anon_counter += 1
        yield (func, blocks_by_label)

def update_to_graph(bbs: bm.BrilScript, app_graph_dict: Dict[bm.BrilFunction, Tuple[nx.DiGraph, Dict[bm.BrilInstruction_Label, List[bm.BrilInstruction]]]]):
    """
    Build one control-flow graph per function of ``bbs`` and store it in
    ``app_graph_dict`` as ``function -> (graph, block_dict)``.

    Every basic block becomes a node named after its label, carrying an
    ``instructions`` attribute (the label followed by the block body). Edges
    carry a ``reason`` attribute:

    * ``'[enter]'``       from ENTRY_POINT_NAME to the first block,
    * ``'[fallthrough]'`` from a block without a terminator to the next one,
    * the terminator's text for ``jmp`` / ``br`` targets and for ``ret``
      (which points at RETURN_POINT_NAME).

    A block that consists of a label only (e.g. two consecutive labels, or a
    trailing label) has no last instruction to inspect; it is treated as
    falling through. A final block without a terminator gets no outgoing edge.

    :param bbs: the loaded Bril script.
    :param app_graph_dict: output mapping, updated in place.
    """
    for each_func, block_dict in iter_func_blocks(bbs):
        print("Function: {}".format(each_func.name))
        fdg = nx.DiGraph()  # function directed graph
        last_label = None   # previous block, if it falls through into this one
        is_first = True
        for each_label, each_block in block_dict.items():
            fdg.add_node(each_label.label, instructions=[each_label]+each_block)
            if last_label is not None:
                fdg.add_edge(last_label.label, each_label.label, reason='[fallthrough]')
            elif is_first:
                fdg.add_edge(ENTRY_POINT_NAME, each_label.label, reason='[enter]')
                is_first = False

            if not each_block:
                # Label-only block: nothing can redirect control, so it simply
                # continues into whatever block comes next.
                last_label = each_label
                continue

            final_instr = each_block[-1]
            if final_instr.op in ('jmp', 'br'):
                for redirect_instr_to_dest in final_instr.labels:
                    fdg.add_edge(each_label.label, redirect_instr_to_dest, reason=final_instr.to_briltxt())
                last_label = None
            elif final_instr.op in ('ret',):
                fdg.add_edge(each_label.label, RETURN_POINT_NAME, reason=final_instr.to_briltxt())
                last_label = None
            else:
                last_label = each_label
        app_graph_dict[each_func] = (fdg, block_dict)

# Per-function results: function -> (control-flow graph, label -> block body).
# Filled in place by update_to_graph and then annotated by the later cells.
app_graph: Dict[bm.BrilFunction, Tuple[nx.DiGraph, Dict[bm.BrilInstruction_Label, List[bm.BrilInstruction]]]] = {}
update_to_graph(bbs, app_graph)

Function: main


In [62]:
# for each basic block, generate set of variables and expressions that:
# 1. are used before defined
# 2. are modified in the block
# Mark used before defined variables as set(GEN) and modified variables as set(KILL)
# GEN = {v | v is used before defined}
# KILL = {v | v is modified}
# EXPR = {e | e is computed in the block}

def get__args_used_before_assign__assigned__calc_expr_available_at_bb_end(instrs: List["bm.BrilInstruction"]) -> Tuple[Set[str], Set[str], Set["Expr"]]:
    """
    Summarise one basic block for the dataflow analyses.
    This is from HW2, but modified to return the sets instead of printing them.

    :param instrs: the instructions of the block, in order. Each one is read
        through its ``args``, ``dest`` and ``op`` attributes.
    :return: a tuple ``(used_first, written, avail_exprs)`` where

        * ``used_first``  - variables read before any write to them in the block,
        * ``written``     - variables assigned anywhere in the block,
        * ``avail_exprs`` - expressions computed in the block whose operands
          are not overwritten afterwards (i.e. still valid at the block end).
    """
    used_first: Set[str] = set()
    written: Set[str] = set()
    avail_exprs: Set["Expr"] = set()
    for instr in instrs:
        operands = instr.args if instr.args else []
        used_first.update(set(operands) - written)
        if not instr.dest:
            continue
        # Writing `dest` invalidates every expression that was computed from it.
        avail_exprs = {expr for expr in avail_exprs if instr.dest not in expr.args}
        written.add(instr.dest)
        # `id`, `const` and `call` are not tracked as reusable expressions.
        # An instruction that overwrites one of its own operands (x = add x y)
        # does not leave its expression available either: the operand it was
        # computed from no longer holds the same value.
        if operands and instr.op not in ('id', 'const', 'call') and instr.dest not in operands:
            avail_exprs.add(Expr(instr.op, operands))
    return used_first, written, avail_exprs

# for each basic block, generate set of expressions that are being computed.
# def gen_expr_sets(block: List[bm.BrilInstruction]) -> Set[Expr]:
#     for instr in block:
#         if instr.args and instr.dest:
#     return exprs

def gen_kill_expr_sets(block: List[bm.BrilInstruction]) -> Tuple[Set[str],Set[str],Set[Expr]]:
    """
    Return ``(GEN, KILL, EXPR)`` for one basic block.

    Short alias for
    ``get__args_used_before_assign__assigned__calc_expr_available_at_bb_end``;
    the three returned sets are stored under those names on the graph nodes.
    """
    return get__args_used_before_assign__assigned__calc_expr_available_at_bb_end(block)

def update_gen_kill_sets(app_graph: Dict[bm.BrilFunction, Tuple[nx.DiGraph, Dict[bm.BrilInstruction_Label, List[bm.BrilInstruction]]]]):
    """
    Attach ``GEN``, ``KILL`` and ``EXPR`` attributes to every node of every
    function graph in ``app_graph``.

    Nodes that have an ``instructions`` attribute get the sets computed by
    ``gen_kill_expr_sets``; nodes without one (or with an empty one) get three
    empty sets.
    """
    for graph, _blocks in app_graph.values():
        for node_name, attrs in graph.nodes(data=True):
            instrs: List[bm.BrilInstruction] = attrs.get('instructions', None)
            if instrs:
                gen, kill, exprs = gen_kill_expr_sets(instrs)
            else:
                # Fresh sets per node so that nodes never share a set object.
                gen, kill, exprs = set(), set(), set()
            graph.nodes[node_name].update(GEN=gen, KILL=kill, EXPR=exprs)

# Annotate every graph node with its per-block GEN / KILL / EXPR sets.
update_gen_kill_sets(app_graph)

# Small read-only accessors used by the fixed-point passes. The attribute
# getters fall back to an empty set when the node does not carry the attribute.

def _get_node_gen_set(fdg, node_name):
    """Return the node's ``GEN`` set, or an empty set if unset."""
    return fdg.nodes[node_name].get('GEN', set())


def _get_node_kill_set(fdg, node_name):
    """Return the node's ``KILL`` set, or an empty set if unset."""
    return fdg.nodes[node_name].get('KILL', set())


def _get_node_expr_set(fdg, node_name):
    """Return the node's ``EXPR`` set, or an empty set if unset."""
    return fdg.nodes[node_name].get('EXPR', set())


def _get_node_in_set(fdg, node_name):
    """Return the node's ``IN`` set, or an empty set if unset."""
    return fdg.nodes[node_name].get('IN', set())


def _get_node_out_set(fdg, node_name):
    """Return the node's ``OUT`` set, or an empty set if unset."""
    return fdg.nodes[node_name].get('OUT', set())


def _get_node_pred_set(fdg, node_name):
    """Return the names of the node's predecessors as a set."""
    return set(fdg.predecessors(node_name))


def _get_node_succ_set(fdg, node_name):
    """Return the names of the node's successors as a set."""
    return set(fdg.successors(node_name))


In [63]:
def _fdg_update_liveness_sets(fdg: nx.DiGraph) -> bool:
    """
    Run one sweep of the liveness equations over every node of ``fdg``.

    For each node, ``OUT`` is the union of the successors' ``IN`` sets (empty
    for RETURN_POINT_NAME) and ``IN`` is ``GEN | (OUT - KILL)``. The node's
    stored ``IN`` / ``OUT`` are overwritten when either differs.

    :return: True if at least one node's ``IN`` or ``OUT`` changed.
    """
    changed = False
    for node_name in fdg.nodes:
        if node_name == RETURN_POINT_NAME:
            live_out = set()
        else:
            successors = _get_node_succ_set(fdg, node_name)
            succ_live_ins = [_get_node_in_set(fdg, succ) for succ in successors]
            # Node names may contain a real newline; escape it for the log line.
            printable_name = node_name.replace('\n', '\\n')
            live_out = set().union(*succ_live_ins)
            print(f"Node: {printable_name}, Succ: {successors}=>{succ_live_ins}, OUT: {live_out}")

        live_in = _get_node_gen_set(fdg, node_name) | (live_out - _get_node_kill_set(fdg, node_name))

        if live_in == _get_node_in_set(fdg, node_name) and live_out == _get_node_out_set(fdg, node_name):
            continue
        fdg.nodes[node_name]['IN'] = live_in
        fdg.nodes[node_name]['OUT'] = live_out
        changed = True
    return changed

def _fdg_update_availability_sets(fdg: nx.DiGraph) -> bool:
    """
    Run one sweep of the available-expressions equations over every node.

    For each node, ``IN`` is the intersection of the predecessors' ``OUT``
    sets (empty for ENTRY_POINT_NAME or when there are no predecessors) and
    ``OUT`` is ``IN | EXPR`` minus every expression that mentions a variable
    in the node's ``KILL`` set. ``OUT`` is derived from the ``IN`` computed in
    this same sweep, not from the value stored by the previous sweep. The
    node's stored ``IN`` / ``OUT`` are overwritten when either differs.

    :return: True if at least one node's ``IN`` or ``OUT`` changed.
    """
    has_changed = False
    for this_node, _ in fdg.nodes(data=True):
        if this_node == ENTRY_POINT_NAME:
            _in = set()
        else:
            _temp_pred_give = [_get_node_out_set(fdg, each_pred_node) for each_pred_node in _get_node_pred_set(fdg, this_node)]
            _in = set.intersection(*_temp_pred_give) if _temp_pred_give else set()
        # gather all exprs that were computed before the end of this block
        _candidates: Set[Expr] = set.union(_in, _get_node_expr_set(fdg, this_node))
        # remove such expr that include variables that were modified in this block
        # NOTE(review): this filter also drops expressions generated by this
        # very block when one of their operands was written earlier in the
        # block; that is conservative (fewer expressions reported available).
        _killed = _get_node_kill_set(fdg, this_node)
        _out: Set[Expr] = {each_expr for each_expr in _candidates if _killed.isdisjoint(each_expr.args)}
        _print_this_node_name = this_node.replace('\n', '\\n')
        print(f"Node: {_print_this_node_name}, IN: {_in}, OUT: {_out}")
        if _in != _get_node_in_set(fdg, this_node) or _out != _get_node_out_set(fdg, this_node):
            fdg.nodes[this_node]['IN'] = _in
            fdg.nodes[this_node]['OUT'] = _out
            has_changed = True
    return has_changed

def update_liveness_sets(app_graph: Dict[bm.BrilFunction, Tuple[nx.DiGraph, Dict[bm.BrilInstruction_Label, List[bm.BrilInstruction]]]]):
    """
    Iterate the liveness sweep over all function graphs until nothing changes.

    Each round prints a banner and sweeps every graph once; the loop stops
    after the first round in which no graph reported a change.
    """
    while True:
        print("Updating liveness sets")
        # Build the full list first so every graph is swept in every round.
        round_changes = [_fdg_update_liveness_sets(graph) for graph, _blocks in app_graph.values()]
        if not any(round_changes):
            break

def update_availability_sets(app_graph: Dict[bm.BrilFunction, Tuple[nx.DiGraph, Dict[bm.BrilInstruction_Label, List[bm.BrilInstruction]]]]):
    """
    Iterate the availability sweep over all function graphs until nothing changes.

    Each round prints a banner and sweeps every graph once; the loop stops
    after the first round in which no graph reported a change.
    """
    while True:
        print("Updating availability sets")
        # Build the full list first so every graph is swept in every round.
        round_changes = [_fdg_update_availability_sets(graph) for graph, _blocks in app_graph.values()]
        if not any(round_changes):
            break

# Run the analysis selected by ANALYSIS. The historical misspelling
# 'availabililty' is still accepted next to the correct one, and anything else
# is rejected loudly instead of leaving the graph without IN/OUT sets.
if ANALYSIS == 'liveness':
    update_liveness_sets(app_graph)
elif ANALYSIS in ('availabililty', 'availability'):
    update_availability_sets(app_graph)
else:
    raise ValueError(f"Unknown ANALYSIS {ANALYSIS!r}; expected 'liveness' or 'availability'")


Updating liveness sets
Node: _f574._anon0, Succ: {'right', 'left'}=>[set(), set()], OUT: set()
Node: ENTRY\nPOINT, Succ: {'_f574._anon0'}=>[{'cond'}], OUT: {'cond'}
Node: left, Succ: {'end'}=>[set()], OUT: set()
Node: right, Succ: {'end'}=>[set()], OUT: set()
Node: end, Succ: set()=>[], OUT: set()
Updating liveness sets
Node: _f574._anon0, Succ: {'right', 'left'}=>[set(), set()], OUT: set()
Node: ENTRY\nPOINT, Succ: {'_f574._anon0'}=>[{'cond'}], OUT: {'cond'}
Node: left, Succ: {'end'}=>[{'a', 'c'}], OUT: {'a', 'c'}
Node: right, Succ: {'end'}=>[{'a', 'c'}], OUT: {'a', 'c'}
Node: end, Succ: set()=>[], OUT: set()
Updating liveness sets
Node: _f574._anon0, Succ: {'right', 'left'}=>[set(), {'a'}], OUT: {'a'}
Node: ENTRY\nPOINT, Succ: {'_f574._anon0'}=>[{'cond'}], OUT: {'cond'}
Node: left, Succ: {'end'}=>[{'a', 'c'}], OUT: {'a', 'c'}
Node: right, Succ: {'end'}=>[{'a', 'c'}], OUT: {'a', 'c'}
Node: end, Succ: set()=>[], OUT: set()
Updating liveness sets
Node: _f574._anon0, Succ: {'right', 'lef

In [64]:
# Visualise the graph of the first function only; app_graph stores
# (graph, block_dict) per function, hence the trailing [0].
fdg = app_graph.get(bbs.functions[0])[0]

# Plot with pyvis
net = pv.network.Network(
    directed=True,
    neighborhood_highlight=True,
    notebook=True,
    cdn_resources="remote", 
    height="100vh",
    width="100vw",
)

# net.show_buttons(['nodes', 'edges', 'layout', 'interaction', 'manipulation', 'physics', 'selection', 'renderer']) # Show part 3 in the plot (optional)
net.from_nx(fdg) # Create directly from nx graph
# Lookup from node id to its position in net.nodes. Only referenced by the
# commented-out edge code further down.
_net_node_name2idx: Dict[str, int] = {node['id']: idx for idx, node in enumerate(net.nodes)}
_map_node_name2idx = lambda node_name: _net_node_name2idx.get(node_name, None)

# net.show_buttons(['physics',])

# Layout / interaction tuning for the rendered page.
net.options.interaction.hover = True
net.options.physics.use_force_atlas_2based(dict(
    gravity=-100,
    spring_strength=0.05,
    central_gravity=0.01,
    spring_length=200,
    damping=0.45,
    overlap=0,
))

def generate_set_str(showing_set: Iterable) -> str:
    """
    Render a collection as ``{a,b,c}`` for display in node labels/tooltips.

    Elements are converted with ``str`` (so non-string members such as Expr
    objects are accepted) and sorted, which makes the text independent of set
    iteration order. An empty collection renders as ``{}``.

    :param showing_set: any iterable of displayable items.
    :return: the items joined by commas and wrapped in braces.
    """
    return '{' + ','.join(sorted(str(item) for item in showing_set)) + '}'

# Traverse the net nodes of PyVis and convert the data
# Each node dict still carries the analysis attributes copied from the graph;
# they are popped here and folded into the node's display text instead.
for node in net.nodes:
    _in = generate_set_str(node.pop('IN', set()))
    _out = generate_set_str(node.pop('OUT', set()))
    _gen = generate_set_str(node.pop('GEN', set()))
    _kill = generate_set_str(node.pop('KILL', set()))
    _expr = generate_set_str([str(x) for x in node.pop('EXPR', set())])

    # Tooltip: the block's instructions (if any) followed by GEN/KILL/EXPR.
    node['title'] = ''
    if 'instructions' in node:
        # remove the 'instructions' key from node, and set 'title' to the text of those instructions
        node['title'] = "\n  ".join([obj.to_briltxt() if hasattr(obj, 'to_briltxt') else str(obj) for obj in node.pop('instructions', [])])
        node['shape'] = 'box'
    node['title'] += "\n\n" + f"GEN: { _gen }" + "\n" + f"KILL: { _kill }" + "\n" + f"EXPR: { _expr }"
    node['title'] = node['title'].strip(' \t\n\r')
    # title word layout change
    # find and replace each blank line (double newline) with a dashed separator line
    # find and replace each double space with a single full-width space (U+3000)
    node['title'] = node['title'].replace('\n\n', '\n--------\n').replace('  ', '\u3000')

    # Visible label: IN above and OUT below the existing label. Despite the
    # variable name, these are the IN/OUT sets of whichever analysis was run.
    _liveness_labels = f"IN: {_in}", f"OUT: {_out}"
    node['label'] = _liveness_labels[0] + '\n' + node['label'] + '\n' + _liveness_labels[1]

    # The synthetic entry/return nodes are drawn differently from real blocks.
    if node['id'] in (ENTRY_POINT_NAME, RETURN_POINT_NAME):
        node['color'] = 'grey'
        node['shape'] = 'circle'

# Show why each edge exists (the 'reason' stored when the graph was built).
for edge in net.edges:
    _reason = edge.pop('reason', None)
    if _reason:
        edge['label'] = _reason
    # _src_node = net.nodes[_map_node_name2idx( edge['from'] )]
    # _dst_node = net.nodes[_map_node_name2idx( edge['to'] )]
    # if _src_node.get('out', set()) & _dst_node.get('in', set()):
    #     _liveness_src_out_label = ', '.join(_src_node.get('out', set()) & _dst_node.get('in', set()))

# Write the interactive page next to the notebook.
net.save_graph('test.html')

In [65]:
# Dump the basic blocks of the loaded script as text, for cross-checking
# against the rendered graph.
print_blocks(bbs)

BrilFunction ::	main ( cond<bool> ) -> None: <14 instr>
anonymous block:
  BrilInstruction_Const ::	a (int) <- [const] 47
  BrilInstruction_Const ::	b (int) <- [const] 42
  BrilInstruction_EffOp ::	[br] cond {left, right} 
block "left":
  BrilInstruction_Const ::	b (int) <- [const] 1
  BrilInstruction_Const ::	c (int) <- [const] 5
  BrilInstruction_EffOp ::	[jmp]  {end} 
block "right":
  BrilInstruction_Const ::	a (int) <- [const] 2
  BrilInstruction_Const ::	c (int) <- [const] 10
  BrilInstruction_EffOp ::	[jmp]  {end} 
block "end":
  BrilInstruction_ValOp ::	d (int) <- [sub] a, c 
  BrilInstruction_EffOp ::	[print] d 

