In [2]:
import itertools

from src.config import DATA_PATH_LC_SQL_TRAIN_CSV
import csv
import os
import numpy as np
import copy
from tqdm import tqdm

# # make hints


In [3]:
# # 首先是能根据train.csv获取SQL语句
def load_data(file_name):
    joins = []
    predicates = []
    tables = []

    # Load queries
    with open(file_name, 'r') as f:
        data_raw = list(list(rec) for rec in csv.reader(f, delimiter='#'))
        print(data_raw[0])
        for row in data_raw:
            tables.append(row[0].split(','))
            joins.append(row[1].split(','))
            predicates.append(row[2].split(','))
    return tables, joins, predicates


tables, joins, predicates = load_data(DATA_PATH_LC_SQL_TRAIN_CSV)

['title t,movie_info_idx mi_idx', 't.id=mi_idx.movie_id', 't.kind_id,=,7,mi_idx.info_type_id,>,99', '283812']


In [221]:
def cartesian(array):
    """
        做笛卡尔基
    """
    return list(itertools.product(*array))


In [223]:
def generate_scan_hints(tables):
    """
        生成Scan Hint: Seq Scan 和 Index Scan
        tables形如: ["person p", "table t"]
    """
    scan_methods = ["SeqScan({})", "IndexScan({})"]
    hint_candidate = []
    for table in tables:
        # # 取别名
        table_alias = table.split(" ")[1]
        table_candidate = list(map(lambda method: method.format(table_alias), scan_methods))
        hint_candidate.append(table_candidate)
    # # hint_candidate:
    #   [
    #       ['SeqScan(p)', 'IndexScan(p)'],  // p的两种方法
    #       ['SeqScan(t)', 'IndexScan(t)']   // t的两种方法
    #   ]
    candidates = list(map(" ".join, cartesian(hint_candidate)))
    return candidates


generate_scan_hints(["person p", "table t"])

['SeqScan(p) SeqScan(t)',
 'SeqScan(p) IndexScan(t)',
 'IndexScan(p) SeqScan(t)',
 'IndexScan(p) IndexScan(t)']

In [230]:
def add_one_rel(cur, join_tables):
    """
        添加一个关系
    """
    extended_order = []
    for table in join_tables:
        if table not in cur:
            extended_order.extend([
                ["(", *cur, table, ")"],
                ["(", table, *cur, ")"]
            ])
    return extended_order


add_one_rel(['b', 'c'], ['a'])

[['(', 'b', 'c', 'a', ')'], ['(', 'a', 'b', 'c', ')']]

In [232]:
def generate_join_order_hints(tables):
    """
        不改了
    """
    # # 取表的别名
    table_alias = [x.split(" ")[1] for x in tables]
    # #
    str_order_length = 3 * len(tables) - 2
    join_orders = []
    starter = copy.deepcopy(table_alias)
    stack = [[each] for each in starter]
    while len(stack) != 0:
        cur = stack.pop(0)
        if len(cur) < str_order_length:
            extended_orders = add_one_rel(cur, table_alias)
            stack.extend(extended_orders)
        else:
            join_orders.append(cur)
    str_join_orders = [" ".join(each) for each in join_orders]
    str_join_orders = set(str_join_orders)  # # 去重
    # # 放表的顺序
    join_orders_string = list(map("Leading ({})".format, str_join_orders))
    return join_orders_string, join_orders


generate_join_order_hints(["person p", "table t", "ok o"])

(['Leading (( o ( t p ) ))',
  'Leading (( ( p t ) o ))',
  'Leading (( ( p o ) t ))',
  'Leading (( t ( o p ) ))',
  'Leading (( ( o t ) p ))',
  'Leading (( ( o p ) t ))',
  'Leading (( ( t p ) o ))',
  'Leading (( ( t o ) p ))',
  'Leading (( o ( p t ) ))',
  'Leading (( p ( o t ) ))',
  'Leading (( p ( t o ) ))',
  'Leading (( t ( p o ) ))'],
 [['(', '(', 'p', 't', ')', 'o', ')'],
  ['(', 'o', '(', 'p', 't', ')', ')'],
  ['(', '(', 't', 'p', ')', 'o', ')'],
  ['(', 'o', '(', 't', 'p', ')', ')'],
  ['(', '(', 'p', 'o', ')', 't', ')'],
  ['(', 't', '(', 'p', 'o', ')', ')'],
  ['(', '(', 'o', 'p', ')', 't', ')'],
  ['(', 't', '(', 'o', 'p', ')', ')'],
  ['(', '(', 't', 'p', ')', 'o', ')'],
  ['(', 'o', '(', 't', 'p', ')', ')'],
  ['(', '(', 'p', 't', ')', 'o', ')'],
  ['(', 'o', '(', 'p', 't', ')', ')'],
  ['(', '(', 't', 'o', ')', 'p', ')'],
  ['(', 'p', '(', 't', 'o', ')', ')'],
  ['(', '(', 'o', 't', ')', 'p', ')'],
  ['(', 'p', '(', 'o', 't', ')', ')'],
  ['(', '(', 'o', 'p', ')',

(['Leading (( o ( t p ) ))',
  'Leading (( ( p t ) o ))',
  'Leading (( ( p o ) t ))',
  'Leading (( t ( o p ) ))',
  'Leading (( ( o t ) p ))',
  'Leading (( ( o p ) t ))',
  'Leading (( ( t p ) o ))',
  'Leading (( ( t o ) p ))',
  'Leading (( o ( p t ) ))',
  'Leading (( p ( o t ) ))',
  'Leading (( p ( t o ) ))',
  'Leading (( t ( p o ) ))'],
 [['(', '(', 'p', 't', ')', 'o', ')'],
  ['(', 'o', '(', 'p', 't', ')', ')'],
  ['(', '(', 't', 'p', ')', 'o', ')'],
  ['(', 'o', '(', 't', 'p', ')', ')'],
  ['(', '(', 'p', 'o', ')', 't', ')'],
  ['(', 't', '(', 'p', 'o', ')', ')'],
  ['(', '(', 'o', 'p', ')', 't', ')'],
  ['(', 't', '(', 'o', 'p', ')', ')'],
  ['(', '(', 't', 'p', ')', 'o', ')'],
  ['(', 'o', '(', 't', 'p', ')', ')'],
  ['(', '(', 'p', 't', ')', 'o', ')'],
  ['(', 'o', '(', 'p', 't', ')', ')'],
  ['(', '(', 't', 'o', ')', 'p', ')'],
  ['(', 'p', '(', 't', 'o', ')', ')'],
  ['(', '(', 'o', 't', ')', 'p', ')'],
  ['(', 'p', '(', 'o', 't', ')', ')'],
  ['(', '(', 'o', 'p', ')',

In [217]:
def construct_sql(table, join, predicates, method="explain"):
    tables = ", ".join(table)
    if join != [""] and predicates != [""]:
        joins = " and ".join(join)
        sql = method + " select count(*) from {} where {} and {}"
    elif join != [""] and predicates == [""]:
        joins = " and ".join(join)
        sql = method + " select count(*) from {} where {} {}"
    elif join == [""] and predicates != [""]:
        joins = ""
        sql = method + " select count(*) from {} where {} {}"
    else:
        joins = ""
        sql = method + " select count(*) from {} {} {}"
    l = []
    for n in range(len(predicates) // 3):
        l.append(' '.join(predicates[n * 3:n * 3 + 3]))
    predicates = " and ".join(l)
    return sql.format(tables, joins, predicates) + ";"



In [219]:
# %%
def parse_order(order):
    """
        就是获取一个排序
    """
    left = 0
    right = len(order) - 1
    parsed_order = []
    while left < right:
        if order[left] == "(" and order[right] == ")":
            left += 1
            right -= 1
        elif order[left] == "(":
            parsed_order.insert(0, order[right])
            right -= 1
        elif order[right] == ")":
            parsed_order.insert(0, order[left])
            left += 1
        else:
            parsed_order.insert(0, order[right])
            parsed_order.insert(0, order[left])
            left += 1
            right -= 1
    return parsed_order


def generate_join_method_hints_from_orders(join_order_hints, join_orders_list):
    """
        join_order_hints: ['Leading (( t mi_idx ))', 'Leading (( mi_idx t ))']
        join_orders_list: [
            ['(', 't', 'mi_idx', ')'],
            ['(', 'mi_idx', 't', ')'],
            ['(', 'mi_idx', 't', ')'],
            ['(', 't', 'mi_idx', ')']
         ]

    """
    join_methods = ["NestLoop({})", "MergeJoin({})", "HashJoin({})"]

    join_hints = []

    for order_hint, order in zip(join_order_hints, join_orders_list):
        parsed_order = parse_order(order)
        # # JOIN ORDER
        join_order = []
        for idx in range(2, len(parsed_order) + 1):
            join_order.append(" ".join(parsed_order[0:idx]))
        # #
        join_candidate = []
        for level in join_order:
            join_candidate.append([each.format(level) for each in join_methods])
        candidates = list(map(lambda x: " ".join(x), cartesian(join_candidate)))
        join_hints.extend(list(map(lambda each: f"{each} {order_hint}", candidates)))
    if not join_hints:
        join_hints = [""]
    return join_hints


In [220]:
def generate_hint_queries(table, join, predicate, command="explain"):
    # # 获取scan_hints提示
    scan_hints = generate_scan_hints(table)
    # # 生成 Join Order的Hint
    join_order_hints, join_orders = generate_join_order_hints(table)
    join_hints = generate_join_method_hints_from_orders(join_order_hints, join_orders)

    # # 生成sql
    sql = construct_sql(table, join, predicate, command)
    # # 通过笛卡尔积排列组合生成所有的hints
    queries = list(map(lambda each: f"/*+ {each} */ {sql}", map(" ".join, cartesian([scan_hints, join_hints]))))

    return queries, sql


In [216]:
for idx, (table, join, predicate) in enumerate(zip(tables, joins, predicates)):
    # query_idx = 9
    queries_with_hint, sql = generate_hint_queries(table, join, predicate, command="explain analyse")

    if idx > 10: break