In [2]:
import sqlglot
import mysql.connector
from functools import partial
import json

database = mysql.connector.connect(
    user='root', 
    password='password',
    host='127.0.0.1', 
    port=3307,
    database="TPCH",
)

cursor = database.cursor()

def partial_transformer(node, table_index_info):
    if isinstance(node, sqlglot.exp.Table):
        table_name = node.this.output_name
        if table_name in table_index_info:
            use_index_flag = table_index_info[table_name]["use_index_flag"]
            indexes = [ f"index_{table_name}_{column}" for column in table_index_info[table_name]["indexes"]]
        else:
            use_index_flag = True
            indexes = list()

        table_hint = sqlglot.exp.IndexTableHint()
        table_hint.set("this", "FORCE" if use_index_flag else "IGNORE")
        indexes_identifier = sqlglot.exp.Identifier()
        indexes_identifier.set("this", ", ".join(indexes))
        table_hint.set("expressions", table_hint.expressions + [indexes_identifier])
        node.set("hints", node.expressions + [table_hint])
        return node
    return node

def get_query_cost(query, table_index_info) -> float:
    expression_tree = sqlglot.parse_one(query)
    transformer = partial(partial_transformer, table_index_info=table_index_info)
    transformed_tree = expression_tree.transform(transformer)
    index_specified_query = transformed_tree.sql()
    print(index_specified_query)
    cursor.execute(f"EXPLAIN FORMAT='JSON' {index_specified_query}")
    

    query_cost = json.loads(cursor.fetchall()[0][0])["query_block"]["cost_info"]["query_cost"]

    return float(query_cost)

def get_index_cost(table_index_info) -> float:
    index_name_list = []
    for table_name in table_index_info:
        for column in table_index_info[table_name]["indexes"]:
            index_name_list.append(f"index_{table_name}_{column}")

    index_name_list_string = "('"+ "','".join(index_name_list) + "')"
        
    cursor.execute(f"SELECT ROUND(SUM(stat_value * @@innodb_page_size / 1024 / 1024), 2) size_in_mb FROM mysql.innodb_index_stats WHERE stat_name = 'size' AND index_name != 'PRIMARY' AND database_name = 'TPCH' AND index_name IN {index_name_list_string}")
    return float(cursor.fetchone()[0])

In [55]:
cursor.execute(f"EXPLAIN ANALYZE FORMAT=JSON {queries[1]}")
cursor.fetchone()

ProgrammingError: 1235 (42000): This version of MySQL doesn't yet support 'EXPLAIN ANALYZE with JSON format'

In [43]:
cursor.execute(f"EXPLAIN ANALYZE {queries[1]}")
print(json.loads(cursor.fetchall()[0][0]))

JSONDecodeError: Expecting value: line 1 column 1 (char 0)

In [4]:
import pandas as pd

In [29]:
tables_list = pd.read_sql("SHOW TABLES", database)["Tables_in_TPCH"].tolist()
index_table_mapping = dict()
index_list = list()
for table in tables_list:
    query_result = pd.read_sql(f"SHOW indexes FROM {table} WHERE key_name LIKE 'index_%'", database)
    index_table_mapping[table] = query_result["Column_name"].tolist()
    index_list += query_result["Key_name"].tolist()

  tables_list = pd.read_sql("SHOW TABLES", database)["Tables_in_TPCH"].tolist()
  query_result = pd.read_sql(f"SHOW indexes FROM {table} WHERE key_name LIKE 'index_%'", database)
  query_result = pd.read_sql(f"SHOW indexes FROM {table} WHERE key_name LIKE 'index_%'", database)
  query_result = pd.read_sql(f"SHOW indexes FROM {table} WHERE key_name LIKE 'index_%'", database)


In [30]:
index_list

['index_customer_c_nationkey',
 'index_customer_c_mktsegment',
 'index_lineitem_l_partkey',
 'index_lineitem_l_suppkey',
 'index_lineitem_l_linenumber',
 'index_lineitem_l_quantity',
 'index_lineitem_l_discount',
 'index_lineitem_l_tax',
 'index_lineitem_l_returnflag',
 'index_lineitem_l_linestatus',
 'index_lineitem_l_shipdate',
 'index_lineitem_l_commitdate',
 'index_lineitem_l_receiptdate',
 'index_lineitem_l_shipinstruct',
 'index_lineitem_l_shipmode',
 'index_orders_o_custkey',
 'index_orders_o_orderstatus',
 'index_orders_o_orderpriority',
 'index_orders_o_clerk',
 'index_orders_o_shippriority',
 'index_part_p_mfgr',
 'index_part_p_brand',
 'index_part_p_type',
 'index_part_p_size',
 'index_part_p_container',
 'index_partsupp_ps_suppkey']

In [7]:
indexes = index_list[0:5]

In [24]:
table_names = [x.split("_")[1] for x in indexes]

In [26]:
table_index_info = dict()
for i in range(len(indexes)):
    index = indexes[i]
    table = table_names[i]
    index_col = index.replace(f"index_{table}_", "")
    if table in table_index_info:
        table_index_info[table]["indexes"].append(index_col)
    else:
        table_index_info[table] = {
            "use_index_flag": True,
            "indexes": [index_col],
        }
    

In [27]:
table_index_info

{'customer': {'use_index_flag': True,
  'indexes': ['c_nationkey', 'c_mktsegment']},
 'lineitem': {'use_index_flag': True,
  'indexes': ['l_partkey', 'l_suppkey', 'l_linenumber']}}

In [34]:
sql_reader = open("queries/test_queries.sql")
queries = sql_reader.read().split(";")
sql_reader.close()

In [None]:
# find columns involved in join and columns involved in where clauses

In [58]:
test_expression = sqlglot.parse_one(queries[1])


In [67]:
wheres = list(test_expression.find_all(sqlglot.exp.Where))

In [75]:
testing = wheres[0]

In [93]:
isinstance(sqlglot.exp.Where, sqlglot.exp.Subquery)

False

In [99]:
def my_walk_prune(node, parent, arg_key):
    exclude = isinstance(node, sqlglot.exp.Subquery)
    return exclude

In [105]:
for x in testing.walk(prune=my_walk_prune):
    for y in x:
        if isinstance(x[0], sqlglot.exp.Identifier):
            print(x)
    # print(x)

((IDENTIFIER this: ps_supplycost, quoted: False), (COLUMN this: 
  (IDENTIFIER this: ps_supplycost, quoted: False)), 'this')
((IDENTIFIER this: ps_supplycost, quoted: False), (COLUMN this: 
  (IDENTIFIER this: ps_supplycost, quoted: False)), 'this')
((IDENTIFIER this: ps_supplycost, quoted: False), (COLUMN this: 
  (IDENTIFIER this: ps_supplycost, quoted: False)), 'this')
((IDENTIFIER this: r_name, quoted: False), (COLUMN this: 
  (IDENTIFIER this: r_name, quoted: False)), 'this')
((IDENTIFIER this: r_name, quoted: False), (COLUMN this: 
  (IDENTIFIER this: r_name, quoted: False)), 'this')
((IDENTIFIER this: r_name, quoted: False), (COLUMN this: 
  (IDENTIFIER this: r_name, quoted: False)), 'this')
((IDENTIFIER this: n_regionkey, quoted: False), (COLUMN this: 
  (IDENTIFIER this: n_regionkey, quoted: False)), 'this')
((IDENTIFIER this: n_regionkey, quoted: False), (COLUMN this: 
  (IDENTIFIER this: n_regionkey, quoted: False)), 'this')
((IDENTIFIER this: n_regionkey, quoted: False), (C

In [72]:
list(wheres[0].find_all(sqlglot.exp.Where))

[(WHERE this: 
   (AND this: 
     (AND this: 
       (AND this: 
         (AND this: 
           (AND this: 
             (AND this: 
               (AND this: 
                 (EQ this: 
                   (COLUMN this: 
                     (IDENTIFIER this: p_partkey, quoted: False)), expression: 
                   (COLUMN this: 
                     (IDENTIFIER this: ps_partkey, quoted: False))), expression: 
                 (EQ this: 
                   (COLUMN this: 
                     (IDENTIFIER this: s_suppkey, quoted: False)), expression: 
                   (COLUMN this: 
                     (IDENTIFIER this: ps_suppkey, quoted: False)))), expression: 
               (EQ this: 
                 (COLUMN this: 
                   (IDENTIFIER this: p_size, quoted: False)), expression: 
                 (LITERAL this: 32, is_string: False))), expression: 
             (LIKE this: 
               (COLUMN this: 
                 (IDENTIFIER this: p_type, quoted: False)), ex

In [66]:
[list(x.find_all(sqlglot.exp.Identifier)) for x in ]

[[(IDENTIFIER this: ps_supplycost, quoted: False),
  (IDENTIFIER this: r_name, quoted: False),
  (IDENTIFIER this: n_regionkey, quoted: False),
  (IDENTIFIER this: r_regionkey, quoted: False),
  (IDENTIFIER this: s_nationkey, quoted: False),
  (IDENTIFIER this: n_nationkey, quoted: False),
  (IDENTIFIER this: ps_supplycost, quoted: False),
  (IDENTIFIER this: partsupp, quoted: False),
  (IDENTIFIER this: supplier, quoted: False),
  (IDENTIFIER this: nation, quoted: False),
  (IDENTIFIER this: region, quoted: False),
  (IDENTIFIER this: p_type, quoted: False),
  (IDENTIFIER this: p_size, quoted: False),
  (IDENTIFIER this: r_name, quoted: False),
  (IDENTIFIER this: p_partkey, quoted: False),
  (IDENTIFIER this: ps_partkey, quoted: False),
  (IDENTIFIER this: s_suppkey, quoted: False),
  (IDENTIFIER this: ps_suppkey, quoted: False),
  (IDENTIFIER this: n_regionkey, quoted: False),
  (IDENTIFIER this: r_regionkey, quoted: False),
  (IDENTIFIER this: s_nationkey, quoted: False),
  (IDENTI

In [55]:
base_cost = get_query_cost(queries[0], dict())

/* using 1697822052 as a seed to the RNG */ SELECT l_returnflag, l_linestatus, SUM(l_quantity) AS sum_qty, SUM(l_extendedprice) AS sum_base_price, SUM(l_extendedprice * (1 - l_discount)) AS sum_disc_price, SUM(l_extendedprice * (1 - l_discount) * (1 + l_tax)) AS sum_charge, AVG(l_quantity) AS avg_qty, AVG(l_extendedprice) AS avg_price, AVG(l_discount) AS avg_disc, COUNT(*) AS count_order FROM lineitem FORCE INDEX () WHERE l_shipdate <= CAST('1998-12-01' AS DATE) - INTERVAL '83' day GROUP BY l_returnflag, l_linestatus ORDER BY l_returnflag, l_linestatus


ProgrammingError: 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near ') WHERE l_shipdate <= CAST('1998-12-01' AS DATE) - INTERVAL '83' day GROUP BY l_' at line 1

In [39]:
print(queries[0])

-- using 1697822052 as a seed to the RNG



select
	l_returnflag,
	l_linestatus,
	sum(l_quantity) as sum_qty,
	sum(l_extendedprice) as sum_base_price,
	sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
	sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
	avg(l_quantity) as avg_qty,
	avg(l_extendedprice) as avg_price,
	avg(l_discount) as avg_disc,
	count(*) as count_order
from
	lineitem
where
	l_shipdate <= date '1998-12-01' - interval '83' day
group by
	l_returnflag,
	l_linestatus
order by
	l_returnflag,
	l_linestatus


In [7]:
action_space_mapping = {
    0: "do_nothing",
}

max_key = 0
for index in index_list:
    max_key += 1
    action_space_mapping[max_key] = f"add_{index}"

for index in index_list:
    max_key += 1
    action_space_mapping[max_key] = f"remove_{index}"

In [8]:
action_space_mapping

{0: 'do_nothing',
 1: 'add_index_customer_c_nationkey',
 2: 'add_index_customer_c_mktsegment',
 3: 'add_index_lineitem_l_partkey',
 4: 'add_index_lineitem_l_suppkey',
 5: 'add_index_lineitem_l_linenumber',
 6: 'add_index_lineitem_l_quantity',
 7: 'add_index_lineitem_l_discount',
 8: 'add_index_lineitem_l_tax',
 9: 'add_index_lineitem_l_returnflag',
 10: 'add_index_lineitem_l_linestatus',
 11: 'add_index_lineitem_l_shipdate',
 12: 'add_index_lineitem_l_commitdate',
 13: 'add_index_lineitem_l_receiptdate',
 14: 'add_index_lineitem_l_shipinstruct',
 15: 'add_index_lineitem_l_shipmode',
 16: 'add_index_orders_o_custkey',
 17: 'add_index_orders_o_orderstatus',
 18: 'add_index_orders_o_orderpriority',
 19: 'add_index_orders_o_clerk',
 20: 'add_index_orders_o_shippriority',
 21: 'add_index_part_p_mfgr',
 22: 'add_index_part_p_brand',
 23: 'add_index_part_p_type',
 24: 'add_index_part_p_size',
 25: 'add_index_part_p_container',
 26: 'add_index_partsupp_ps_suppkey',
 27: 'remove_index_customer_

In [9]:
import tensorflow as tf
import numpy as np
from tensorflow import keras

from collections import deque
import time
import random

2023-11-11 10:40:41.686427: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [10]:
RANDOM_SEED = 5
tf.random.set_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

In [24]:
train_episodes = 5
test_episodes = 1

In [12]:
observation_space_mapping = {   
    x:x for x in range(100)
}

In [27]:
class DbEnv():
    def __init__(self):
        self.seed = RANDOM_SEED
        self.observation_space = list(observation_space_mapping.keys())
        self.action_space = list(action_space_mapping.keys())
        self.current_indexes = []
        self.training_complete = False
    def reset(self):
        if not self.training_complete:
            self.current_indexes = []
    def step(self, action):
        self.current_indexes.append(action)
        new_observation = list(observation_space_mapping.values())
        reward = random.randint(0, 5)
        done = False
        info = 1
        return new_observation, reward, done, info
            
    def close(self):
        self.training_complete=True

In [28]:
env = DbEnv()

In [29]:
batch_size = 50

In [30]:
def agent(state_shape, action_shape):
    """ The agent maps X-states to Y-actions
    e.g. The neural network output is [.1, .7, .1, .3]
    The highest value 0.7 is the Q-Value.
    The index of the highest action (0.7) is action #1.
    """
    learning_rate = 0.001
    init = tf.keras.initializers.HeUniform()
    model = keras.Sequential()
    model.add(keras.layers.Dense(24, input_shape=state_shape, activation='relu', kernel_initializer=init))
    model.add(keras.layers.Dense(12, activation='relu', kernel_initializer=init))
    model.add(keras.layers.Dense(action_shape, activation='linear', kernel_initializer=init))
    model.compile(loss=tf.keras.losses.Huber(), optimizer=tf.keras.optimizers.Adam(lr=learning_rate), metrics=['accuracy'])
    return model


def get_qs(model, state, step):
    return model.predict(state.reshape([1, state.shape[0]]))[0]


def train(env, replay_memory, model, target_model, done):
    learning_rate = 0.7 # Learning rate
    discount_factor = 0.618

    MIN_REPLAY_SIZE = 1000
    if len(replay_memory) < MIN_REPLAY_SIZE:
        return

    batch_size = 64 * 2
    error_flag = True
    counter = 1
    while error_flag:
        try:
            mini_batch = random.sample(replay_memory, batch_size)
            current_states = np.array([transition[0] for transition in mini_batch])
            error_flag = False
        except:
            print(f"Failure occured for {counter}")
            counter += 1


    current_qs_list = model.predict(current_states)
    new_current_states = np.array([transition[3] for transition in mini_batch])
    future_qs_list = target_model.predict(new_current_states)

    X = []
    Y = []
    for index, (observation, action, reward, new_observation, done) in enumerate(mini_batch):
        if not done:
            max_future_q = reward + discount_factor * np.max(future_qs_list[index])
        else:
            max_future_q = reward

        current_qs = current_qs_list[index]
        current_qs[action] = (1 - learning_rate) * current_qs[action] + learning_rate * max_future_q

        X.append(observation)
        Y.append(current_qs)
    model.fit(np.array(X), np.array(Y), batch_size=batch_size, verbose=0, shuffle=True)



In [36]:
env.observation_space

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99]

In [33]:
epsilon = 1 # Epsilon-greedy algorithm in initialized at 1 meaning every step is random at the start
max_epsilon = 1 # You can't explore more than 100% of the time
min_epsilon = 0.01 # At a minimum, we'll always explore 1% of the time
decay = 0.01

# 1. Initialize the Target and Main models
# Main Model (updated every 4 steps)
model = agent([len(env.observation_space)], len(env.action_space))
# Target Model (updated every 100 steps)
target_model = agent([len(env.observation_space)], len(env.action_space))
target_model.set_weights(model.get_weights())

replay_memory = deque(maxlen=50_000)

target_update_counter = 0

# X = states, y = actions
X = []
y = []

steps_to_update_target_model = 0

for episode in range(train_episodes):
    print(f"Currently running episode {episode}")
    steps = 0
    total_training_rewards = 0
    observation = env.reset()
    steps = 0
    while steps < 10:
        steps += 1
        steps_to_update_target_model += 1
        random_number = np.random.rand()
        # 2. Explore using the Epsilon Greedy Exploration Strategy
        if random_number <= epsilon:
            # Explore
            action = random.choice(env.action_space)
        else:
            # Exploit best known action
            # model dims are (batch, env.observation_space.n)
            encoded = observation
            # encoded_reshaped = encoded.reshape([1, encoded.shape[0]])
            predicted = model.predict(encoded).flatten()
            action = np.argmax(predicted)
        new_observation, reward, done, info = env.step(action)
        replay_memory.append([observation, action, reward, new_observation, done])

        # 3. Update the Main Network using the Bellman Equation
        if steps_to_update_target_model % 4 == 0 or done:
            mini_batch = train(env, replay_memory, model, target_model, done)

        observation = new_observation
        total_training_rewards += reward

        if done:
            print('Total training rewards: {} after n steps = {} with final reward = {}'.format(total_training_rewards, episode, reward))
            total_training_rewards += 1

            if steps_to_update_target_model >= 100:
                print('Copying main network weights to the target network weights')
                target_model.set_weights(model.get_weights())
                steps_to_update_target_model = 0
            break

    epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay * episode)
env.close()




Currently running episode 0
Currently running episode 1
Currently running episode 2
Currently running episode 3
Currently running episode 4


<keras.src.engine.sequential.Sequential at 0x12ccd4370>