In [1]:
import pydot

In [2]:
# EXPLAIN statement: customers that have more than four order-detail rows
# with quantityOrdered > 1. Adjacent literals are concatenated by the parser,
# so every fragment except the last carries its own trailing space.
query = (
    "EXPLAIN "
    "SELECT c.customerNumber, COUNT(*) "
    "FROM customers c, orders o, orderdetails od "
    "WHERE c.customerNumber = o.customerNumber AND o.orderNumber = od.orderNumber AND od.quantityOrdered > 1 "
    "GROUP BY c.customerNumber "
    "HAVING COUNT(*) > 4;"
)

In [3]:
# Captured plan text, one 1-tuple per plan line (presumably the rows returned
# by running `query` -- confirm). The leading spaces of each line are
# significant: the parser below uses them to work out how operators nest.
# zip() over a single list wraps every string in its own 1-tuple.
result = list(zip([
    'HashAggregate  (cost=121.11..122.63 rows=41 width=16)',
    '  Group Key: c.customernumber',
    '  Filter: (count(*) > 4)',
    '  ->  Hash Join  (cost=17.08..98.64 rows=2996 width=8)',
    '        Hash Cond: (o.customernumber = c.customernumber)',
    '        ->  Hash Join  (cost=11.34..84.75 rows=2996 width=8)',
    '              Hash Cond: (od.ordernumber = o.ordernumber)',
    '              ->  Seq Scan on orderdetails od  (cost=0.00..65.45 rows=2996 width=8)',
    '                    Filter: (quantityordered > 1)',
    '              ->  Hash  (cost=7.26..7.26 rows=326 width=16)',
    '                    ->  Seq Scan on orders o  (cost=0.00..7.26 rows=326 width=16)',
    '        ->  Hash  (cost=4.22..4.22 rows=122 width=8)',
    '              ->  Seq Scan on customers c  (cost=0.00..4.22 rows=122 width=8)',
]))

In [4]:
# Echo the plan as plain text: each row is a tuple whose first element is the
# plan line.
for res in result:
    print(res[0])

HashAggregate  (cost=121.11..122.63 rows=41 width=16)
  Group Key: c.customernumber
  Filter: (count(*) > 4)
  ->  Hash Join  (cost=17.08..98.64 rows=2996 width=8)
        Hash Cond: (o.customernumber = c.customernumber)
        ->  Hash Join  (cost=11.34..84.75 rows=2996 width=8)
              Hash Cond: (od.ordernumber = o.ordernumber)
              ->  Seq Scan on orderdetails od  (cost=0.00..65.45 rows=2996 width=8)
                    Filter: (quantityordered > 1)
              ->  Hash  (cost=7.26..7.26 rows=326 width=16)
                    ->  Seq Scan on orders o  (cost=0.00..7.26 rows=326 width=16)
        ->  Hash  (cost=4.22..4.22 rows=122 width=8)
              ->  Seq Scan on customers c  (cost=0.00..4.22 rows=122 width=8)


In [5]:
# Order matters when these names are matched as substrings of a plan line:
# 'HashAggregate' and 'Hash Join' both contain 'Hash', so they have to be
# listed before it.
operators = [
    'HashAggregate',
    'Hash Join',
    'Hash',
    'Seq Scan',
]

In [6]:
class TreeNode(object):
    """One operator in a parsed query-plan tree.

    Attributes:
        title: operator name; first line of the node's text.
        val: position marker supplied by the caller. parseQuery stores the
            plan line's leading-space count here and compares these values
            to decide which node is the parent of which.
        description: remaining text for the node; may span several lines.
        parent: the TreeNode this node hangs under, or None.
        children: child TreeNodes in insertion order.
        pydot_node: slot for the plotting object; None until something
            (plotQueryTree in this notebook) assigns it.
    """

    def __init__(self, title, val, description):
        self.title, self.val, self.description = title, val, description
        self.parent, self.pydot_node = None, None
        self.children = []

    def setParent(self, node):
        """Record `node` as the parent of this node."""
        self.parent = node

    def addChild(self, node):
        """Append `node` to the children and point its parent back here."""
        self.children.append(node)
        node.setParent(self)

    def to_string(self):
        """Return the title and the description joined by a newline."""
        return '\n'.join((self.title, self.description))


In [7]:
def getNumOfPrecedingSpaces(line):
    """Return how many space characters `line` starts with.

    Only the plain space character is counted; a tab or any other character
    ends the run. An empty string gives 0 and a string consisting solely of
    spaces gives its full length (the loop-based version raised on the former
    and was off by one on the latter).

    Args:
        line: the text to inspect.

    Returns:
        The number of leading spaces, as an int.
    """
    return len(line) - len(line.lstrip(' '))

def getOperatorFromLine(line):
    """Identify the plan operator named in `line`.

    The operator is the known name that occurs *earliest* in the line, so a
    name that merely shows up later (for example inside a table name such as
    'HashLog') cannot shadow the real operator. When two names start at the
    same position the one listed first wins, which is why the longer
    'Hash...' names precede plain 'Hash'.

    Args:
        line: one line of plan text.

    Returns:
        A tuple `(operator, description)` where `description` is everything
        in `line` after the operator name. If no known name is present the
        tuple is `('Unknown Operator', 'Could not find operator')`.
    """
    # Note that the order of operators is important (see docstring).
    operators = ['HashAggregate', 'Hash Join', 'Hash', 'Seq Scan']

    best_op = None
    best_pos = -1
    for op in operators:
        pos = line.find(op)
        # Strict '<' keeps the earlier-listed name on a positional tie.
        if pos != -1 and (best_op is None or pos < best_pos):
            best_op, best_pos = op, pos

    if best_op is None:
        return 'Unknown Operator', 'Could not find operator'
    # Slice instead of split(op)[1] so a later repeat of the name does not
    # cut the description short.
    return best_op, line[best_pos + len(best_op):]

In [8]:
# TODO: unfinished sketch of a dedicated Hash Join operator class (not implemented).
# class HashJoinOperator:
#     def __init__(self, hash_condition):
#         r1, r2, v1, v2

In [9]:
def parseQuery(result):
    """Turn plan rows into a tree of TreeNode objects.

    The first line and every line whose text begins with the '->' marker
    start a new operator; any other line is appended to the description of
    the operator most recently started. Nesting is derived from the number
    of leading spaces: a new operator becomes the child of the closest
    preceding operator that is indented less than it.

    Args:
        result: iterable of rows whose first element is one line of plan
            text.

    Returns:
        A list of every TreeNode created, in the order the operators appear.
        The first element is the root; the list is empty when `result` is.
    """
    all_nodes = []
    stack = []  # open operators, from the root down to the latest one
    for i, res in enumerate(result):
        line = res[0]
        val = getNumOfPrecedingSpaces(line)
        # An operator starts on the very first line or on a line whose first
        # non-space characters are '->'. Testing the start of the text rather
        # than "'->' anywhere" keeps detail lines that merely contain '->'
        # (e.g. a Filter using a '->' operator) attached to their operator.
        if i == 0 or line.lstrip(' ').startswith('->'):
            op, description = getOperatorFromLine(line)
            node = TreeNode(op, val, description)

            # Drop operators that are not ancestors of the new one, i.e.
            # everything indented as deep as, or deeper than, this line.
            while len(stack) > 0 and stack[-1].val >= val:
                stack.pop()

            if len(stack) > 0:
                # the new operator is a child of the nearest shallower one
                stack[-1].addChild(node)

            all_nodes.append(node)
            stack.append(node)

        # if the line is not an operator, it belongs as a description of the previous operator
        else:
            stack[-1].description += '\n'
            # Colons (not semicolons) are what gets stripped here: pydot
            # mishandles ':' in a node name, and this text ends up in one.
            line = line.replace(':', '')
            stack[-1].description += line[val:]
    return all_nodes

In [10]:
def printTreeBFS(root):
    """Print every node of the tree in breadth-first order.

    One line group is printed per node in the form "level <n>, <text>",
    where <n> is the node's depth (the root is level 0) and <text> is the
    node's to_string() result.

    Args:
        root: the node to start from; it and its descendants must provide
            `to_string()` and `children`.
    """
    # A single growing list doubles as the BFS queue; `cursor` marks the
    # next unvisited entry. Each entry remembers the depth it was found at.
    pending = [(0, root)]
    cursor = 0
    while cursor < len(pending):
        depth, node = pending[cursor]
        cursor += 1
        print("level {}, {}".format(depth, node.to_string()))
        pending.extend((depth + 1, child) for child in node.children)

In [12]:
def plotBFS(root, graph):
    """Add one edge per parent/child pair to `graph`, breadth-first.

    Args:
        root: node to start from; every node reached must expose `children`
            and an already-populated `pydot_node`.
        graph: object with an `add_edge` method (a pydot graph in this
            notebook) that receives each `pydot.Edge`.
    """
    # The list is used as a FIFO queue; `head` points at the next node whose
    # children still have to be linked.
    queue = [root]
    head = 0
    while head < len(queue):
        parent = queue[head]
        head += 1
        for child in parent.children:
            graph.add_edge(pydot.Edge(parent.pydot_node, child.pydot_node))
            queue.append(child)

In [14]:
def plotQueryTree(result, filename='test.png'):
    """Render the query plan in `result` as a tree and save it as a PNG.

    Args:
        result: plan rows in the form parseQuery accepts.
        filename: path of the image file to write.

    Returns:
        `filename`, unchanged.

    Raises:
        IndexError: if `result` contains no plan lines, so there is no root
            to draw.
    """
    all_nodes = parseQuery(result)
    if not all_nodes:
        # Same exception type as the former all_nodes[0] lookup, but raised
        # up front and with a message that says what went wrong.
        raise IndexError('cannot plot an empty query plan')

    graph = pydot.Dot(graph_type='graph')
    for index, node in enumerate(all_nodes):
        # Graphviz identifies a node by its name, so using the display text
        # as the name would merge operators whose text is identical (e.g. two
        # 'Hash' steps with the same estimate). Give each node a unique name
        # and carry the text in the label instead.
        node.pydot_node = pydot.Node(
            'node{}'.format(index),
            label=node.to_string(),
            shape='box',
        )
        graph.add_node(node.pydot_node)
    plotBFS(all_nodes[0], graph)
    graph.write_png(filename)
    return filename

In [15]:
# Draw the plan held in `result`; with no filename given the image goes to the
# default 'test.png', and the returned filename is shown as the cell output.
plotQueryTree(result)

'test.png'

![title](test.png)


In [16]:
class Operator(object):
    """A plan operator together with its parsed cost estimate.

    Attributes:
        op_name: operator name, e.g. 'Hash Join'.
        description: free-form description of the operator.
        est_min_cost: first number of the `cost=<a>..<b>` pair, as a float.
        est_max_cost: second number of the `cost=<a>..<b>` pair, as a float.
        rows: the `rows=<n>` estimate, as an int.

    NOTE(review): if the plans come from PostgreSQL, the two cost numbers are
    the startup and total cost rather than a min/max -- confirm before
    relying on the attribute names.
    """

    def __init__(self, op_name, cost_estimate, description):
        """Store the operator's name and description and parse its cost.

        Args:
            op_name: one of the operator names, e.g. 'HashAggregate',
                'Hash Join', 'Nested Loop', 'Index Scan', 'Hash', 'Seq Scan'.
            cost_estimate: cost text such as
                'cost=2936.20..2937.73 rows=41 width=16'; see parseCost for
                the accepted variations.
            description: just a description for now.
        """
        self.op_name = op_name
        self.description = description
        self.parseCost(cost_estimate)

    def parseCost(self, cost_estimate):
        """Extract the cost pair and the row estimate from `cost_estimate`.

        Accepts the bare form ('cost=1.00..2.00 rows=3 width=8') as well as
        the form found in plan lines, where the estimate is wrapped in
        parentheses and may follow other text
        (' on t  (cost=1.00..2.00 rows=3 width=8)').

        Sets `est_min_cost`, `est_max_cost` and `rows` on the instance.

        Raises:
            ValueError: if the text has fewer than two fields or a field is
                not a valid number.
        """
        # Skip anything in front of the estimate, then drop the surrounding
        # parentheses from each whitespace-separated field.
        start = cost_estimate.find('cost=')
        text = cost_estimate[start:] if start != -1 else cost_estimate
        parts = [token.strip('()') for token in text.split()]
        if len(parts) < 2:
            raise ValueError(
                'expected "cost=<a>..<b> rows=<n>", got {!r}'.format(cost_estimate)
            )

        costs = self.parseHeader(parts[0], header='cost=')
        min_cost, max_cost = costs.split('..')
        rows = self.parseHeader(parts[1], header='rows=')

        self.est_min_cost = float(min_cost)
        self.est_max_cost = float(max_cost)
        self.rows = int(rows)

    @staticmethod
    def parseHeader(line, header):
        """Return `line` without a leading `header`, or unchanged if absent."""
        if line.startswith(header):
            return line[len(header):]
        return line

In [None]:
# Sample operator line (text after the '->' marker) for trying out the parser.
op_line = 'Hash Join  (cost=11.34..84.75 rows=2996 width=8)'