In [1]:
# how to generate forward graph and make it into table format
import json 
from collections import defaultdict
import math
import numpy as np

In [2]:
# Location of the exported graph JSON (machine-specific absolute path).
path = "/Users/wangyangzuo/Desktop/公司/sd_forward.json"

# Parse inside a context manager so the file handle is closed deterministically
# instead of being left for the garbage collector.
with open(path, 'r') as f:
    total_graph = json.load(f)

# 'graph' is the nested node tree; 'links' maps a node id to the ids it links to.
graph         = total_graph['graph']
links         = total_graph['links']

In [3]:
# Build the reverse adjacency map (target id -> list of source ids) and count
# how many edges arrive at each target.
reverse_links = defaultdict(list)
in_degree     = defaultdict(int)
for k,v in links.items():
    for link in v:
        reverse_links[link].append(k)
        # The edge k -> link is an *incoming* edge of `link`; bumping the
        # counter of `k` would have produced the out-degree instead.
        in_degree[link] += 1

In [4]:
# Membership (parent/child) relations: walk the tree once (stack order) and
# record, for every node, the id of its parent, plus the payload of every
# node that has no children.
father_dict = {}
id_info = {}
next_nodes = [[graph, None]]
while len(next_nodes) > 0:
    node, parent = next_nodes.pop()
    key_id = list(node)[0]
    payload = node[key_id]
    if parent:
        father_dict[key_id] = list(parent)[0]
    has_children = "children" in payload and bool(payload['children'])
    if has_children:
        next_nodes.extend([child, node] for child in payload['children'])
    else:
        # Leaf: keep its payload for later name lookups.
        id_info[key_id] = payload

In [5]:
# Leaf node id -> leaf node name.
id_name = {node_id: info["name"] for node_id, info in id_info.items()}

In [6]:
import re
# Next id handed out to a synthetic fused op: build_a_fuse_op reads it and then increments it.
# NOTE(review): presumably chosen so it cannot collide with ids present in the loaded graph — confirm.
global_idx = 10000000

In [9]:

def rewrite_graph(graph, links, source_ids, target_ids ):
    """Placeholder for a graph rewrite step; not implemented, always returns None."""
    return None



def bfs_search(start_id, names, preset=None):
    """Match a chain of node names by walking forward links breadth-first.

    Starting at ``start_id``, nodes are dequeued in BFS order and each dequeued
    node's name (looked up in the module-level ``id_name``) must equal the next
    expected name; the expected sequence is the start node's own name followed
    by ``names``.

    Args:
        start_id: id of the node the walk starts from.
        names: list of node names expected after the start node, in order.
        preset: optional iterable of node ids that must not be entered.
            ``None`` (the default) means "no exclusions"; previously the
            default raised ``TypeError`` inside ``set().union(None)``.

    Returns:
        The list of matched ids (as ints) once every expected name is matched,
        ``[]`` as soon as a dequeued node has the wrong name, or the partial
        match collected so far if the frontier runs dry first.
        NOTE(review): callers treat any non-empty list as a hit, so a partial
        match is indistinguishable from a full one — confirm this is intended.

    Raises:
        KeyError: if a visited id is missing from ``id_name``.
    """
    from collections import deque

    # Normalise ids to int so membership tests cannot be defeated by str/int mixes.
    visited = {int(i) for i in preset} if preset else set()
    start = int(start_id)
    expected = [id_name[str(start)]] + names
    matched = []
    queue = deque([start])
    while queue:
        current = queue.popleft()
        if id_name[str(current)] != expected[len(matched)]:
            return []
        matched.append(current)
        visited.add(current)
        if len(matched) == len(expected):
            return matched
        for link in links.get(str(current), ()):
            successor = int(link)
            if successor not in visited:
                queue.append(successor)
    return matched

def bfs_search_bwd(end_id, names, preset=None):
    """Match a chain of node names by walking reverse links breadth-first.

    Mirror image of the forward search: starting at ``end_id``, predecessors
    are taken from the module-level ``reverse_links`` and each dequeued node's
    name (via ``id_name``) must equal the next expected name; the expected
    sequence is the end node's own name followed by ``names``.

    Args:
        end_id: id of the node the walk starts from.
        names: list of node names expected after the end node, in order.
        preset: optional iterable of node ids that must not be entered.
            ``None`` (the default) means "no exclusions"; previously the
            default raised ``TypeError`` inside ``set().union(None)``.

    Returns:
        The list of matched ids once every expected name is matched, ``[]`` as
        soon as a dequeued node has the wrong name, or the partial match
        collected so far if the frontier runs dry first. All ids are returned
        as ints; before, predecessor ids stayed strings, so the visited/preset
        check (which holds ints) never excluded anything.

    Raises:
        KeyError: if a visited id is missing from ``id_name``.
    """
    from collections import deque

    visited = {int(i) for i in preset} if preset else set()
    end = int(end_id)
    expected = [id_name[str(end)]] + names
    matched = []
    queue = deque([end])
    while queue:
        current = queue.popleft()
        if id_name[str(current)] != expected[len(matched)]:
            return []
        matched.append(current)
        visited.add(current)
        if len(matched) == len(expected):
            return matched
        # .get avoids inserting empty entries when reverse_links is a defaultdict.
        for link in reverse_links.get(current, ()):
            predecessor = int(link)
            if predecessor not in visited:
                queue.append(predecessor)
    return matched

def help_parse_pattern(pattern:str):
    """Split a parenthesis-free pattern into its node names.

    Returns ``(names, top_down)``: a pattern containing ``->`` is split on
    ``->`` and flagged top-down (``True``); anything else is split on ``<-``
    and flagged ``False``.
    """
    assert "(" not in pattern
    if "->" in pattern:
        return pattern.split("->"), True
    return pattern.split("<-"), False

def parse_pattern(pattern:str):
    """Turn a raw pattern string into a list of search descriptors.

    Each descriptor is ``[pre, anchor, names, top_down]``. The first one is the
    main chain (``pre == anchor == -1``); every parenthesised group adds one
    more descriptor ``[0, position_in_main_chain, group_names, group_top_down]``
    and contributes its first name to the main chain at that position.

    Example: ``to->linear->linear->to->(mul<-linear)->add`` yields
    ``[[-1, -1, [to, linear, linear, to, mul, add], True],
       [0, 4, [mul, linear], False]]``.
    """
    groups = re.findall(r'\(.*?\)', pattern)
    if len(groups) == 0:
        print("simple pattern")
        names, direction = help_parse_pattern(pattern)
        return [[-1, -1, names, direction]]

    # Replace every group by a placeholder so the main chain can be split.
    flattened = pattern
    for group in groups:
        flattened = flattened.replace(group, "*")
    main_names, main_direction = help_parse_pattern(flattened)

    sub_patterns = []
    for order, group in enumerate(groups):
        slot = main_names.index("*", order)
        sub_names, sub_direction = help_parse_pattern(group[1:-1])
        print(sub_names, main_names, slot)
        # The group's first name is the node shared with the main chain.
        main_names[slot] = sub_names[0]
        sub_patterns.append([0, slot, sub_names, sub_direction])
    return [[-1, -1, main_names, main_direction], *sub_patterns]

# print(parse_pattern("to->linear->linear->to->(mul<-linear)->add"))

def search_cur_patten(graph, links, pattern):
    """Find the first place in the graph tree where one descriptor matches.

    ``pattern`` is ``[pre, node_id, names, top_down]``. The tree is scanned
    breadth-first; a node is a candidate when ``node_id`` is -1 and its name
    equals ``names[0]``, or when its id equals ``node_id``. For a candidate the
    forward (``top_down``) or backward link search is run with the remaining
    names, excluding the ids in ``pre`` (nothing excluded when ``pre`` is -1).
    The first non-empty search result is returned; a candidate whose search
    comes back empty is skipped without descending into its children.
    Returns None when nothing matches. ``links`` is accepted but not used here.
    """
    pre, node_id, p, top_down = pattern
    first_name, remaining = p[0], p[1:]
    excluded = set(pre) if pre != -1 else set()
    queue = [[graph, None]]
    while queue:
        current, _parent = queue.pop(0)
        cur_id = list(current.keys())[0]
        info = current[cur_id]
        is_candidate = (node_id == -1 and info['name'] == first_name) or int(node_id) == int(cur_id)
        if is_candidate:
            if top_down:
                found = bfs_search(int(cur_id), remaining, excluded)
            else:
                found = bfs_search_bwd(int(cur_id), remaining, excluded)
            if found:
                return found
            continue
        if "children" in info and info['children']:
            for child in info['children']:
                queue.append([child, current])

def search_pattern(graph, links, raw_pattern):
    """Match a raw pattern (optionally with parenthesised side branches).

    The raw string is parsed into descriptors; the first descriptor (main
    chain) is searched on its own, then every further descriptor is anchored
    at the already-matched node at its recorded position and searched with the
    ids matched so far passed along as the exclusion set. Newly found ids are
    appended to the result.

    Returns:
        The list of matched node ids. If the main chain is not found, the
        falsy value produced by the search (``None`` or ``[]``) is returned
        unchanged. If a side branch cannot be anchored or is not found, ``[]``
        is returned instead of raising ``TypeError``/``IndexError`` as before.
    """
    pattern = parse_pattern(raw_pattern)
    print(pattern)
    res = search_cur_patten(graph, links, pattern[0])
    print(res)
    if not res:
        # Nothing to anchor side branches on.
        return res
    for p in pattern[1:]:
        anchor_pos = p[1]
        if anchor_pos >= len(res):
            # The main chain matched only partially and never reached the anchor.
            return []
        p[0] = res
        p[1] = res[anchor_pos]
        tmp = search_cur_patten(graph, links, p)
        if not tmp:
            return []
        for i in tmp:
            if i not in res:
                res.append(i)
    return res

def build_a_fuse_op(name, depth=None, input_shape=None, output_shape=None, inputs_dtypes=None, outputs_dtypes=None,  comment=""):
    """Create a leaf node describing a fused op.

    The node id is taken from the module-level ``global_idx`` counter, which is
    incremented on every call. Returns ``{str(id): payload}`` where the payload
    carries name, depth, idx, shapes, dtypes, comment and ``children=None``.
    """
    global global_idx
    new_id = global_idx
    global_idx = new_id + 1
    payload = dict(
        name=name,
        depth=depth,
        idx=new_id,
        input_shape=input_shape,
        output_shape=output_shape,
        input_dtype=inputs_dtypes,
        output_dtype=outputs_dtypes,
        comment=comment,
        children=None,
    )
    return {str(new_id): payload}

def handle_link(cur_id, replace_id=None):
    """Detach ``cur_id`` from the module-level ``links`` table.

    With a truthy ``replace_id``, every predecessor's link to ``cur_id`` is
    re-pointed at ``replace_id`` and ``replace_id`` takes over ``cur_id``'s
    outgoing link list (the ``cur_id`` entry itself is left in place, as
    before). Without it, the links pointing at ``cur_id`` and the ``cur_id``
    entry are removed.

    Ids are normalised before lookup: ``reverse_links`` is keyed by int ids and
    link lists hold int ids, while ``links`` is keyed by str ids. The previous
    code used the caller's raw id for all three, so a str id found no
    predecessors and ``list.index(str(cur_id))`` raised ``ValueError`` on int
    link lists. Missing entries are now skipped instead of raising ``KeyError``.

    Note: ``reverse_links`` itself is not updated here.
    """
    cur_int = int(cur_id)
    cur_key = str(cur_id)

    # .get (rather than indexing) so a defaultdict does not grow empty entries.
    father_ids = reverse_links.get(cur_int) or reverse_links.get(cur_key) or []
    for father_id in father_ids:
        targets = links.get(str(father_id))
        if not targets:
            continue
        # First occurrence only, like list.index, but insensitive to str/int ids.
        pos = next((i for i, target in enumerate(targets) if int(target) == cur_int), None)
        if pos is None:
            continue
        if replace_id:
            targets[pos] = int(replace_id)
        else:
            targets.pop(pos)

    if replace_id:
        links[str(replace_id)] = links.get(cur_key, [])
    else:
        links.pop(cur_key, None)

def build_attention_op(badbmm_op):
    """Build a fused "attention" node from the payload of a matched op.

    Depth and dtypes are copied from ``badbmm_op``; the fused input shape is
    ``[shape[1], shape[2], shape[2]]`` and the output shape ``[shape[1]]``,
    where ``shape`` is ``badbmm_op['input_shape']``.
    """
    depth = badbmm_op['depth']
    src_shape = badbmm_op['input_shape']
    second, third = src_shape[1], src_shape[2]
    in_dtype = badbmm_op['input_dtype']
    out_dtype = badbmm_op['output_dtype']
    return build_a_fuse_op(
        "attention",
        depth,
        [second, third, third],
        [second],
        in_dtype,
        out_dtype,
        "fuse op",
    )

def attention_match_and_rewrite(graph, links, raw_pattern):
    """Replace the first match of ``raw_pattern`` in the tree by a fused attention node.

    The first matched node is swapped in place for a node built by
    ``build_attention_op``; all other matched nodes are removed from their
    parents' ``children`` lists. ``handle_link`` is called for each of them:
    with the new node id for the first and last matched node, without it for
    the nodes in between.

    Fixes over the previous single-pass version:
      * ids are compared as ints, so interior matched nodes are actually
        removed (``key_id in res`` compared a str against ints and never hit);
      * nodes are located first and rewritten afterwards, so the fused node id
        always exists before the last matched node is processed;
      * a matched node without a parent no longer raises.

    Returns:
        ``(graph, links, flag)``; ``flag`` is True only when a rewrite happened.
    """
    res            = search_pattern(graph, links, raw_pattern)
    if not res: return graph, links, False

    matched_ids = [int(i) for i in res]
    head_id, tail_id = matched_ids[0], matched_ids[-1]
    wanted = set(matched_ids)

    # Pass 1: locate every matched node (and its parent) without mutating the tree.
    located = {}
    next_nodes = [[graph, None]]
    while next_nodes:
        node, parent = next_nodes.pop(0)
        key_id = list(node.keys())[0]
        node_int = int(key_id)
        if node_int in wanted and node_int not in located:
            located[node_int] = (key_id, node, parent)
        if "children" not in node[key_id] or not node[key_id]['children']:
            continue
        for child in node[key_id]['children']:
            next_nodes.append([child, node])

    # The fused node replaces the head inside its parent's children list,
    # so a head that is missing or has no parent cannot be rewritten.
    if head_id not in located or located[head_id][2] is None:
        return graph, links, False

    # Pass 2a: swap the head for the fused attention node.
    key_id, node, parent = located.pop(head_id)
    parent_id = list(parent.keys())[0]
    attention_op = build_attention_op(node[key_id])
    attention_op_id = list(attention_op.keys())[0]
    siblings = parent[parent_id]['children']
    siblings[siblings.index(node)] = attention_op
    handle_link(key_id, attention_op_id)

    # Pass 2b: drop the remaining matched nodes, in discovery order.
    for node_int, (key_id, node, parent) in located.items():
        if parent is None:
            continue
        parent_id = list(parent.keys())[0]
        parent[parent_id]['children'].remove(node)
        if node_int == tail_id:
            handle_link(key_id, attention_op_id)
        else:
            handle_link(key_id)
    return graph, links, True

# def attention_before_trans(graph, links, patterns=None):
#     if not patterns:
#         patterns = "reshape->permute->reshape->transpose"
#     res = search_pattern(graph, links, patterns)
#     next_nodes = [[graph, None]]
#     flag = False

#     return graph, links

# while 1:
#     graph, links, flag = attention_match_and_rewrite(graph, links, "baddbmm->softmax->to->bmm")
#     if not flag:
#         break

# Smoke run: look up the first reshape->permute->reshape->transpose chain; the cell output lists the matched node ids.
search_pattern(graph, links, "reshape->permute->reshape->transpose")

# print(search_patten( graph, links, "to->linear->linear->to->mul->(add<-linear)"))
# search_pattern(graph, links, "baddbmm->softmax->to->bmm")
# search_cur_patten(graph, links, "baddbmm->softmax->to->bmm")
# search_cur_patten(graph, links, "to->linear->linear->to->(mul<-linear)->add")
# search_cur_patten(graph, links, "add<-linear")
# (-1,-1,[to, linear, linear, to, mul, linear, add], top_down), (0,4 [mul, linear], down_top)
# (a<-d->e)->b->c forbid
# (a<-(f->g))->b->c->d->e allow 

simple pattern
[[-1, -1, ['reshape', 'permute', 'reshape', 'transpose'], True]]
[123144830548560, 123144830566544, 123144830568016, 123144830573904]


[123144830548560, 123144830566544, 123144830568016, 123144830573904]

In [22]:
# Persist the current graph tree as indented JSON.
with open("new_graph2.json", "w") as out_file:
    json.dump(graph, out_file, indent=4)
# graph

In [18]:
# Inspect the reverse adjacency map (target id -> list of source ids).
reverse_links

defaultdict(list,
            {123145429053136: ['123145429082576'],
             123145429030608: ['123145429053136'],
             123144832097936: ['123145450072400'],
             123144832099216: ['123144832097936'],
             123144832295312: ['123144832099216'],
             123144832293520: ['123145429030608'],
             123144832296528: ['123144832293520', '123144832295312'],
             123144832000400: ['123144832296528'],
             123145433053392: ['123144832000400'],
             123145433051600: ['123144832000400'],
             123145433049296: ['123145433053392', '123145433051600'],
             123145433055568: ['123145433049296'],
             123145433057168: ['123145433049296'],
             123145433058768: ['123145433055568', '123145433057168'],
             123145432526288: ['123145433058768'],
             123144832100624: ['123145432526288'],
             123144832105744: ['123144832100624'],
             123144832388432: ['123144832105744'],
       

In [46]:
# Inspect the child id -> parent id map built by the hierarchy walk.
father_dict

{'123145447308112': '123145450164432',
 '123144820935696': '123145447308112',
 '123144820938448': '123144820935696',
 '123144820933840': '123145447308112',
 '123144820935888': '123144820933840',
 '123144823027920': '123145447308112',
 '123144820934096': '123144823027920',
 '123144825163664': '123145447308112',
 '123144820236496': '123144825163664',
 '123144820914576': '123144820236496',
 '123144820913168': '123144820236496',
 '123144820914896': '123144820913168',
 '123144820912464': '123144820236496',
 '123144820911440': '123144820236496',
 '123144820381584': '123144820236496',
 '123144820379472': '123144820236496',
 '123144820876176': '123144820379472',
 '123144820873360': '123144820379472',
 '123144820904848': '123144820873360',
 '123144820908048': '123144820904848',
 '123144820878544': '123144820873360',
 '123144820905232': '123144820878544',
 '123144820876688': '123144820873360',
 '123144820902736': '123144820876688',
 '123144820901072': '123144820876688',
 '123144820878736': '1231

In [12]:
# Rebuild the reverse adjacency map (target id -> list of source ids) from the
# current links table and count how many edges arrive at each target.
reverse_links = defaultdict(list)
in_degree     = defaultdict(int)

for k,v in links.items():
    for link in v:
        reverse_links[link].append(k)
        # The edge k -> link is an *incoming* edge of `link`; bumping the
        # counter of `k` would have produced the out-degree instead.
        in_degree[link] += 1

In [55]:
# Remove one node (by id) from the graph tree. Only the tree is edited here;
# the links table is not touched by this cell.
next_nodes = [[graph, None]]
while len(next_nodes) > 0:
    node, parent = next_nodes.pop(0)
    key_id = list(node)[0]
    if key_id == "123145429012432" and parent:
        # Drop the node from its parent's children; an emptied list becomes None.
        parent_id = list(parent)[0]
        siblings = parent[parent_id]['children']
        siblings.remove(node)
        if not siblings:
            parent[parent_id]['children'] = None
    payload = node[key_id]
    if "children" in payload and payload['children']:
        for child in payload['children']:
            next_nodes.append([child, node])

In [56]:
# Persist the current graph tree as compact JSON.
with open("new_graph.json", "w") as out_file:
    json.dump(graph, out_file)

{'123145429082576': [123145429053136],
 '123145429053136': [123145429030608],
 '123145450072400': [123144832097936],
 '123144832097936': [123144832099216],
 '123144832099216': [123144832295312],
 '123145429030608': [123144832293520],
 '123144832293520': [123144832296528],
 '123144832295312': [123144832296528],
 '123144832296528': [123144832000400],
 '123144832000400': [123145433053392, 123145433051600],
 '123145433053392': [123145433049296],
 '123145433051600': [123145433049296],
 '123145433049296': [123145433055568, 123145433057168],
 '123145433055568': [123145433058768],
 '123145433057168': [123145433058768],
 '123145433058768': [123145432526288],
 '123145432526288': [123144832100624],
 '123144832100624': [123144832105744],
 '123144832105744': [123144832388432],
 '123144832388176': [123145450821776],
 '123145450821776': [123144832453904, 123145429061072, 123144821718672],
 '123145429061072': [123144832400784],
 '123144832400784': [123144832419728],
 '123144832419728': [12314483242753

In [7]:
# Source ids that link to node 123145429012432 (note: indexing a defaultdict creates the key if it is missing).
reverse_links[123145429012432]

['123144830575376']

In [8]:
# Ids that node 123145429012432 links to.
links['123145429012432']

[123144830578704]

In [9]:
# Re-point the predecessor directly at the successor of node 123145429012432.
# Link values are lists of ids, so the target must be wrapped in a list;
# a bare int breaks every loop that iterates over a node's links.
links['123144830575376'] = [123144830578704]

[123145429012432]