In [3]:
import src.config as config
import random
import typing
from src.encoding import TreeBuilder, SQLEncoder
from src.hints import HyperQO
from src.mcts import MCTSHinterSearch
from src.net import TreeNet
from src.tree_lstm import SPINN


def get_hyper_qo_instance():
    random.seed(113)

    tree_builder = TreeBuilder()

    value_network = SPINN(head_num=config.NET_HEAD_NUM,
                          input_size=7 + 2,
                          hidden_size=config.NET_HIDDEN_SIZE,
                          table_num=50,
                          sql_size=40 * 40 + config.MAX_COLUMN_ID).to(config.DEVICE_NAME)

    tree_net = TreeNet(tree_builder=tree_builder, value_network=value_network)
    mcts_searcher = MCTSHinterSearch()

    hint_generator = HyperQO(tree_net=tree_net,
                             sql2vec=SQLEncoder(),
                             value_extractor=tree_builder.value_extractor,
                             mcts_searcher=mcts_searcher)
    return hint_generator


DEFAULT_HINT_GENERATOR = get_hyper_qo_instance()



load model from /home/bing/Projects/PythonProjects/HyperQO/data/model/model_checkpoint.pth


In [74]:
import collections


class DotFileGenerator:
    def __init__(self, graph_name: typing.Optional[str] = None):
        self.__nodes = dict()
        self.__edges = collections.defaultdict(set)
        self.__graph_name = "TestGraph" if graph_name is None else graph_name

    def add_node(self, v_name: str, v_data: typing.Optional[str] = None):
        if v_data is None:
            self.__nodes[v_name] = v_name
        else:
            self.__nodes[v_name] = v_data

    def add_edge(self, v1_name: str, v2_name: str):
        self.__edges[v1_name].add(v2_name)

    def to_graphviz_file(self):
        res = f"digraph {self.__graph_name}{{\n"
        for v_name, v_data in self.__nodes.items():
            if "Scan" in v_data:
                res += f"\t{v_name}[label=\"{v_data}\"][shape=box][color=red][width=3];\n"
            elif "Join" in v_data or 'Nested Loop' in v_data:
                res += f"\t{v_name}[label=\"{v_data}\"][shape=box][color=blue][width=3];\n"
            else:
                res += f"\t{v_name}[label=\"{v_data}\"][shape=box][width=3];\n"
        for v1_name, v2_names in self.__edges.items():
            for v2_name in v2_names:
                res += f"\t{v1_name} -> {v2_name}\n"
        res += "}"
        return res


def get_plan_tree(plan_ob: dict):
    assert "Plan" in plan_ob, "Key Error: Plan"
    dot_file_generator = DotFileGenerator()
    undefined_mark = "undefined"

    def __get_name(level: int, node_type: str):
        return f"""{node_type.replace(' ', '')}{level}"""

    def __inner_dfs(node: dict, level=0):
        node_type = node["Node Type"]
        startup_cost = node.get("Startup Cost", undefined_mark)
        total_cost = node.get("Total Cost", undefined_mark)
        plan_rows = node.get("Plan Rows", undefined_mark)

        dot_file_generator.add_node(
            f"""{__get_name(level, node_type)}""",
            f"""{node_type}\\nstart: {startup_cost}, tol: {total_cost}\\n row: {plan_rows}"""
        )

        for child in node.get("Plans", []):
            __inner_dfs(child, level + 1)
            dot_file_generator.add_edge(f"""{__get_name(level, node_type)}""",
                                        f"""{__get_name(level + 1, child['Node Type'])}""")

    __inner_dfs(plan_ob["Plan"])
    return dot_file_generator


def get_query_plan_detail(plan):
    return {
        'detail': plan,
        'tree': get_plan_tree(plan).to_graphviz_file()
    }

In [75]:
TEST_SQL = "SELECT MIN(chn.name) AS voiced_char_name,\n       MIN(n.name) AS voicing_actress_name,\n       MIN(t.title) AS kung_fu_panda\nFROM aka_name AS an,\n     char_name AS chn,\n     cast_info AS ci,\n     company_name AS cn,\n     info_type AS it,\n     keyword AS k,\n     movie_companies AS mc,\n     movie_info AS mi,\n     movie_keyword AS mk,\n     name AS n,\n     role_type AS rt,\n     title AS t\nWHERE ci.note IN ('(voice)',\n                  '(voice: Japanese version)',\n                  '(voice) (uncredited)',\n                  '(voice: English version)')\n  AND cn.country_code ='[us]'\n  AND it.info = 'release dates'\n  AND mi.info IS NOT NULL\n  AND (mi.info LIKE 'Japan:%201%'\n       OR mi.info LIKE 'USA:%201%')\n  AND n.gender ='f'\n  AND n.name LIKE '%An%'\n  AND rt.role ='actress'\n  AND t.id = mi.movie_id\n  AND t.id = mc.movie_id\n  AND t.id = ci.movie_id\n  AND t.id = mk.movie_id\n  AND mc.movie_id = ci.movie_id\n  AND mc.movie_id = mi.movie_id\n  AND mc.movie_id = mk.movie_id\n  AND mi.movie_id = ci.movie_id\n  AND mi.movie_id = mk.movie_id\n  AND ci.movie_id = mk.movie_id\n  AND cn.id = mc.company_id\n  AND it.id = mi.info_type_id\n  AND n.id = ci.person_id\n  AND rt.id = ci.role_id\n  AND n.id = an.person_id\n  AND ci.person_id = an.person_id\n  AND chn.id = ci.person_role_id\n  AND k.id = mk.keyword_id\n  AND t.title LIKE 'Kung Fu Panda%'\n  AND cn.name = 'DreamWorks Home Entertainment'\nAND k.keyword IN ('murder',\n'marvel-comics',\n'based-on-novel',\n'soothsayer')\nAND t.production_year > 2009;"


def get_all_query_plans(sql):
    (pg_plan_time, pg_latency,
     mcts_time, hinter_plan_time, mphe_time, hinter_latency,
     actual_plans, actual_time,
     chosen_leading_pairs) = DEFAULT_HINT_GENERATOR.optimize(sql)
    res = list()
    details = list()
    for i, ((mean_t, v_t, v2_t), leading, leading_utility, plan_json) in enumerate(chosen_leading_pairs):
        res.append([i + 1, leading, mean_t, leading_utility])
        details.append(get_query_plan_detail(plan_json))
    return res, details


q = get_all_query_plans(TEST_SQL)

  loss_value = self.loss_function(input=v, target=target)


In [79]:
print(q[1][3]['tree'])

digraph TestGraph{
	Aggregate0[label="Aggregate\nstart: 742.02, tol: 742.03\n row: 1"][shape=box][width=3];
	NestedLoop1[label="Nested Loop\nstart: 4.09, tol: 742.01\n row: 1"][shape=box][color=blue][width=3];
	NestedLoop2[label="Nested Loop\nstart: 3.93, tol: 741.84\n row: 1"][shape=box][color=blue][width=3];
	NestedLoop3[label="Nested Loop\nstart: 3.5, tol: 741.32\n row: 1"][shape=box][color=blue][width=3];
	NestedLoop4[label="Nested Loop\nstart: 3.21, tol: 740.99\n row: 1"][shape=box][color=blue][width=3];
	NestedLoop5[label="Nested Loop\nstart: 3.21, tol: 738.57\n row: 1"][shape=box][color=blue][width=3];
	NestedLoop6[label="Nested Loop\nstart: 2.77, tol: 737.15\n row: 1"][shape=box][color=blue][width=3];
	NestedLoop7[label="Nested Loop\nstart: 2.33, tol: 736.66\n row: 1"][shape=box][color=blue][width=3];
	NestedLoop8[label="Nested Loop\nstart: 2.03, tol: 736.29\n row: 1"][shape=box][color=blue][width=3];
	NestedLoop9[label="Nested Loop\nstart: 1.59, tol: 734.88\n row: 1"][shape=bo

In [80]:

a = DotFileGenerator()
a.add_node("a", "A")
a.add_node("b", "B")
a.add_node("c", "C")
a.add_node("d", "D")
a.add_edge("a", "b")
a.add_edge("b", "c")
a.add_edge("b", "d")

print(a.to_graphviz_file())

digraph TestGraph{
	a[label="A"][shape=box][width=3];
	b[label="B"][shape=box][width=3];
	c[label="C"][shape=box][width=3];
	d[label="D"][shape=box][width=3];
	a -> b
	b -> d
	b -> c
}


In [83]:
import psqlparse

a = psqlparse.parse_dict(TEST_SQL)[0]

In [88]:
from src.encoding import RangeVar, WhereCondition


def get_sql_info(sql: str):
    sql_parse_result = psqlparse.parse_dict(sql)[0]["SelectStmt"]
    # # NOTE: BING 2023/5/18 下午9:55 获得FROM子句中的表名, 保存在table_list中
    # # ... 测试用例中的表数目必须大于等于2
    table_list = [str(RangeVar(x["RangeVar"])) for x in sql_parse_result["fromClause"]]

    # # NOTE: BING 2023/5/18 下午10:02 获得WHERE子句中的谓词, 保存在comparison_list中
    comparison_list = [str(WhereCondition(x)) for x in sql_parse_result["whereClause"]["BoolExpr"]["args"]]

    # # 获取SQL语
    return {
        '涉及到的表': table_list,
        '涉及到的谓词': comparison_list
    }


print(get_sql_info(TEST_SQL))

{'涉及到的表': ['aka_name AS an', 'char_name AS chn', 'cast_info AS ci', 'company_name AS cn', 'info_type AS it', 'keyword AS k', 'movie_companies AS mc', 'movie_info AS mi', 'movie_keyword AS mk', 'name AS n', 'role_type AS rt', 'title AS t'], '涉及到的谓词': ["ci.note IN ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)')", "cn.country_code = '[us]'", "it.info = 'release dates'", 'mi.info IS NOT NULL', "( mi.info like 'Japan:%201%' OR mi.info like 'USA:%201%')", "n.gender = 'f'", "n.name like '%An%'", "rt.role = 'actress'", 't.id = mi.movie_id', 't.id = mc.movie_id', 't.id = ci.movie_id', 't.id = mk.movie_id', 'mc.movie_id = ci.movie_id', 'mc.movie_id = mi.movie_id', 'mc.movie_id = mk.movie_id', 'mi.movie_id = ci.movie_id', 'mi.movie_id = mk.movie_id', 'ci.movie_id = mk.movie_id', 'cn.id = mc.company_id', 'it.id = mi.info_type_id', 'n.id = ci.person_id', 'rt.id = ci.role_id', 'n.id = an.person_id', 'ci.person_id = an.person_id', 'chn.id = ci.person_role_

In [85]:
for range_var in a["SelectStmt"]["fromClause"]:
    print()

aka_name an
char_name chn
cast_info ci
company_name cn
info_type it
keyword k
movie_companies mc
movie_info mi
movie_keyword mk
name n
role_type rt
title t
