# Experimental Evaluation Reproducibility Notebook "Soft and Constrained Hypertree Width"

This JupyterLab notebook contains the exact same code that was used to run the experiments shown in the paper. 

## The Scala-based Rewriting Tool 

An important part of our implementation is the Scala-based rewriting tool that takes a CTD and encodes it into a series of SQL queries that effectively implement Yannakakis' algorithm, adapted for CTDs. This tool is based on code from (https://doi.org/10.34726/hss.2024.120310), which was adapted to fit the exact scenario in the paper. Its source code can be found in the same repository as this notebook. 

A user needs to compile and export our version of the Scala tool, as indicated in the ReadMe inside the Scala source code. This will produce a .jar file which needs to be put in the same folder as this notebook (for simplicity), and the name of which must be entered as indicated in the comments of the cells below. 


## Used benchmarks

The benchmarks must be set up externally before this suite of experiments can be run. The paper used the following datasets: 
* TPC-DS, using a scaling factor 10. TPC-DS can be found at its homepage: (https://www.tpc.org/tpcds/).
* Hetionet, as can be found using the data files and queries from (https://github.com/umbra-db/diamond-vldb2024). Please note the FigShare link in the ReadMe of that link to find the actual data files.
* LSQB, with a scaling factor of 10. All information on LSQB, and how to set it up, can be found in its respective paper: (https://dl.acm.org/doi/10.1145/3461837.3464516).

Any user of this notebook is expected to set up the benchmarks on a local installation of Postgres, using the provided scale factors where applicable. Where needed, the name of the database, the user name and the password must be inserted for the experiments to succeed. The comments indicate where these configuration strings need to be adjusted. 


In [None]:
%%bash
pip3 install matplotlib
pip3 install psycopg2-binary
pip3 install networkx
pip3 install colorama
pip3 install termcolor
pip3 install py4j
pip3 install numba
pip3 install pandas
pip3 install line_profiler
pip3 install pyspark

In [None]:
import subprocess
import psycopg2
from pathlib import Path
import json
from json import JSONEncoder

from functools import partial
import sys
from py4j.java_gateway import JavaGateway
from py4j.java_collections import SetConverter, MapConverter, ListConverter
import heapq
from dataclasses import dataclass, field
from typing import Any
import multiprocessing


## Source Python Files for the implementation

In [None]:
import networkx as nx
import re
import pprint
import itertools
import colorama
from termcolor import colored
import functools
colorama.init()

class Edge(object):
    """A named hyperedge: a string label together with a set of vertices."""

    def __init__(self, V, name):
        # Both checks are exact-type checks: subclasses of str/set are rejected.
        assert type(name) is str
        assert type(V) is set
        self.name = name
        self.V = V

    def __repr__(self):
        """Render as ``name(v1,v2,...)``; vertex order follows set iteration."""
        vertex_list = ",".join(str(v) for v in self.V)
        return f"{self.name}({vertex_list})"

class HyperGraph(object):
    def __init__(self):
        self.V = set()
        self.E = list()
        self.edge_dict = dict()


    def markComplete(self):
        newH = HyperGraph()
        count = 1
        for e in self.E:
            newVertex = str(-1*count)
            count = count+1
            newV = e.V.copy()
            newV.add(newVertex)
            newH.add_edge(newV, e.name)
        return newH
        
            

    def grid(n, m):
        h = HyperGraph()
        hc, vc = 0, 0
        for col in range(m-1):
            for row in range(n):
                vi = '{}.{}'.format(row, col)
                vright = '{}.{}'.format(row, col+1)
                horz_name = 'H{}'.format(hc)
                hc = hc+1
                h.add_edge(set([vi, vright]),
                           horz_name)
        for col in range(m):
            for row in range(n-1):
                vi = '{}.{}'.format(row, col)
                vdown = '{}.{}'.format(row+1, col)
                vert_name = 'V{}'.format(vc)
                vc = vc+1
                h.add_edge(set([vi, vdown]),
                           vert_name)
        return h

    def copy(self):
        h = HyperGraph()
        for en, e in self.edge_dict.items():
            h.add_edge(e.V, name=en)
        return h

    def join_copy(self, x, y):
        """Copy of self with vertices x and y joined"""
        if x not in self.V or y not in self.V:
            raise ValueError('Join vertices need to be in hypergraph')
        h = HyperGraph()
        for en, e in self.edge_dict.items():
            e2 = e.V.copy()
            if y in e2:
                e2.remove(y)
                e2.add(x)
            h.add_edge(e2, name=en)
        return h

    def toHyperbench(self):
        s = []
        for en, e in sorted(self.edge_dict.items()):
            s.append('{}({}),'.format(en, ','.join(e.V)))
        return '\n'.join(s)

    def vertex_induced_subg(self, U):
        """Induced by vertex set U"""
        h = HyperGraph()
        for en, e in self.edge_dict.items():
            e2 = e.V.copy()
            e2 = e2 & U
            if e2 != set():
                h.add_edge(e2, name=en)
        return h

    def bridge_subg(self, U):
        """Return the edges touching ``U`` plus one special edge per component.

        ``C`` consists of every edge that shares a vertex with ``U``. For each
        component ``C_i`` of ``self.separate(U)`` a special edge is added that
        collects the non-``U`` vertices of those edges of ``C`` that intersect
        ``C_i``.

        A component that no edge of ``C`` touches contributes no vertices.
        It is skipped instead of crashing in ``set.union(*[])`` as before.
        """
        EC = [en for en, e in self.edge_dict.items() if e.V & U]
        C = self.edge_subg(EC)

        # for each component C_i of rest, compute a special edge Sp_i
        for C_i in self.separate(U):
            Sp_i_parts = [(e.V - U) for e in C.E if e.V & C_i.V]
            # set().union(...) is well defined for an empty parts list too.
            Sp_i = set().union(*Sp_i_parts)
            if not Sp_i:
                continue
            C.add_special_edge(Sp_i)
        return C

    def edge_subg(self, edge_names):
        h = HyperGraph()
        for en in edge_names:
            if en not in self.edge_dict:
                raise ValueError('Edge >{}< not present in hypergraph'.format(en))
            h.add_edge(self.edge_dict[en].copy(), en)
        return h

    @staticmethod
    def fromHyperbench(fname):
        """Parse a HyperBench-style file into a HyperGraph.

        Lines are right-stripped; empty lines and lines starting with ``%``
        are ignored. The remaining text is concatenated and split into
        ``name(v1, v2, ...)`` statements, each of which becomes one edge.

        Declared as a static method (it never took ``self``), and the regular
        expressions are raw strings so that ``\\s``/``\\w`` are not treated as
        (invalid) string escapes.

        NOTE(review): the statement splitter only accepts ``\\w+`` names while
        the edge pattern also accepts ``:``; a name containing a colon would
        lose its prefix during splitting. Left unchanged; confirm the intent.

        Args:
            fname: path of the file to read.

        Returns:
            The parsed HyperGraph.
        """
        EDGE_RE = re.compile(r'\s*([\w:]+)\s?\(([^\)]*)\)')
        STATEMENT_RE = re.compile(r'\w+\s*\([^\)]+\)')

        def split_to_edge_statements(s):
            return list(STATEMENT_RE.findall(s))

        def cleanup_lines(rl):
            a = map(str.rstrip, rl)
            b = filter(lambda x: not x.startswith('%') and len(x) > 0, a)
            return split_to_edge_statements(''.join(b))

        def line_to_edge(l):
            m = EDGE_RE.match(l)
            name = m.group(1)
            e = set(map(str.strip, m.group(2).split(',')))
            return name, e

        with open(fname) as f:
            raw_lines = f.readlines()

        hg = HyperGraph()
        for l in cleanup_lines(raw_lines):
            edge_name, edge = line_to_edge(l)
            hg.add_edge(edge, edge_name)
        return hg

    def add_edge(self, edge, name):
        assert(type(edge) == set)
        obj = Edge(edge,name)
        self.edge_dict[name] = obj
        self.V.update(edge)
        self.E.append(obj)

    def add_special_edge(self, sp):
        SPECIAL_NAME = 'Special'
        # find a name first
        sp_name = None
        for i in itertools.count():
            candidate = SPECIAL_NAME + str(i)
            if candidate not in self.edge_dict:
                sp_name = candidate
                break
        self.add_edge(sp, sp_name)

    def remove_edge(self, name):
        e = self.edge_dict[name]
        del self.edge_dict[name]
        self.E.remove(e)

    def primal_nx(self):
        """Return the primal graph as a networkx Graph.

        Nodes are the vertices of the hypergraph; two vertices are adjacent
        whenever they occur together in some edge.
        """
        G = nx.Graph()
        G.add_nodes_from(self.V)
        for edge in self.E:
            G.add_edges_from(itertools.combinations(edge.V, 2))
        return G

    def incidence_nx(self, without=[]):
        """Return the incidence graph as a networkx Graph.

        Both vertices and edge names become nodes; an edge name is linked to
        each of its vertices. Edges whose name is in ``without`` still appear
        as nodes but get no links.
        """
        G = nx.Graph()
        G.add_nodes_from(self.V)
        G.add_nodes_from(self.edge_dict.keys())
        for name, edge in self.edge_dict.items():
            if name not in without:
                G.add_edges_from((name, v) for v in edge.V)
        return G

    def toPACE(self, special=[]):
        buf = list()
        vertex2int = {v: str(i) for i, v in enumerate(self.V, start=1)}
        buf.append('p htd {} {}'.format(len(self.V),
                                        len(self.E)))
        for i, ei in enumerate(sorted(self.edge_dict.items()), start=1):
            en, e = ei.V
            edgestr = ' '.join(map(lambda v: vertex2int[v], e))
            line = '{} {}'.format(i, edgestr)
            buf.append(line)

        if special is None:
            special = []
        for sp in special:
            if sp is None:
                continue
            edgestr = ' '.join(map(lambda v: vertex2int[v], sp))
            buf.append('s ' + edgestr)
        return '\n'.join(buf)

    def separation_subg(self, U, sep):
        C = HyperGraph()
        cover = U | sep
        for en, e in self.edge_dict.items():
            if e.V.issubset(cover) and not e.V.issubset(sep):
                C.add_edge(e.V, en)
        return C

    def separate(self, sep, only_vertices=False):
        """Returns list of components"""
        assert(type(sep) == set)
        primal = self.primal_nx()
        primal.remove_nodes_from(sep)
        comp_vertices = nx.connected_components(primal)
        if only_vertices:
            return list(comp_vertices)
        comps = [self.separation_subg(U, sep)
                 for U in comp_vertices]
        return comps

    def toVisualSC(self):
        vertex2int = {v: str(i) for i, v in enumerate(self.V, start=1)}
        edges = map(lambda e: map(lambda v: vertex2int[v], e.V), self.E)
        buf = []
        for e in edges:
            buf.append('{'+', '.join(e) + '}')
        return ' '.join(buf)

    def fancy_repr(self, hl=[]):
        """Colourised listing of all edges, sorted by name, one per line.

        Edge names are red and vertices yellow; vertices contained in ``hl``
        are shown bright white on green instead.
        """
        reset = colorama.Style.RESET_ALL
        edge_style = colorama.Fore.RED + colorama.Style.NORMAL
        plain_style = colorama.Fore.YELLOW + colorama.Style.NORMAL
        marked_style = colorama.Fore.WHITE + colorama.Back.GREEN + colorama.Style.BRIGHT

        lines = []
        for en, e in sorted(self.edge_dict.items()):
            coloured = [
                (marked_style if v in hl else plain_style) + v + reset
                for v in e.V
            ]
            lines.append(edge_style + en + reset + '(' + ','.join(coloured) + ')\n')
        return ''.join(lines)

    def __repr__(self):
        """Use the colourised edge listing (without highlights) as representation."""
        text = self.fancy_repr()
        return text

In [None]:

import math

class VertSet(object):
    """Hashable wrapper around a set of vertices.

    Hashing and printing convert every vertex with ``int(...)``, so the
    vertices have to be integer-like (e.g. numeric strings).
    """

    def __init__(self, vertices):
        assert type(vertices) is set
        self.vertices = vertices

    def __eq__(self, other):
        if type(other) is not VertSet:
            return False
        return self.vertices == other.vertices

    def __hash__(self):
        # Sum of the integer values: independent of iteration order.
        total = 0
        for vertex in self.vertices:
            total += int(vertex)
        return total

    def __repr__(self):
        return str(sorted(int(vertex) for vertex in self.vertices))

class Block(object):
    """A block: a head and a tail (disjoint VertSets) plus the head's edge cover.

    Equality and hashing look at head and tail only; the cover is ignored.
    """

    def __init__(self, head, cover, tail):
        assert type(head) is VertSet
        assert type(tail) is VertSet
        assert head.vertices.isdisjoint(tail.vertices)  # disjoint
        self.head = head
        self.cover = cover
        self.tail = tail

    def __eq__(self, other):
        if type(other) is not Block:
            return False
        return self.head == other.head and self.tail == other.tail

    def __hash__(self):
        head_part = sum(hash(v) for v in self.head.vertices)
        tail_part = sum(hash(v) for v in self.tail.vertices)
        return head_part + tail_part

    def __repr__(self):
        return "Block(" + ",".join([str(self.head), str(self.tail), str(self.cover)]) + ")"

    def __lt__(self, other):
        """Containment order (not strict): ``self`` fits inside ``other``.

        True when head+tail of ``self`` is a subset of head+tail of ``other``
        and the tail of ``self`` is a subset of the tail of ``other``.
        """
        own = self.head.vertices.union(self.tail.vertices)
        theirs = other.head.vertices.union(other.tail.vertices)
        if not own.issubset(theirs):
            return False
        return self.tail.vertices.issubset(other.tail.vertices)

    # connected bags filters out any bags for which the induced subgraph over E is not connected
    def connected(self):
        """Return True iff the cover edges form exactly one connected component."""
        cover_graph = HyperGraph()
        for edge in self.cover:
            cover_graph.add_edge(edge.V, edge.name)
        components = cover_graph.separate(set())
        return len(components) == 1  # connected if only one connected comp

    def index(self):
        """Comma-joined, sorted names of the cover edges (key into the weight tables)."""
        return ",".join(sorted(edge.name for edge in self.cover))

class Node:
    """One node of a tree decomposition together with its cost functions.

    All cost functions take ``node_weights``: a mapping from a comma-joined,
    sorted list of edge names to a number.
    """

    def __init__(self, bag, cover, children, weight=0, weight_ideal=0):
        assert type(bag) is VertSet
        self.bag = bag                    # set of vertices
        self.cover = cover                # set of edges
        self.children = children          # child nodes
        self.weight = weight
        self.weight_ideal = weight_ideal

    def getCoverWeight(self, node_weights):
        """Look up the weight stored for this node's complete cover."""
        cover_index = ",".join(sorted(e.name for e in self.cover))
        return node_weights[cover_index]

    def NodeWeight(self, node_weights):
        """Cost of the node itself.

        1 for a single-edge cover; otherwise the cover weight plus
        ``w * log(w)`` for every cover edge with non-zero weight ``w``.
        """
        if len(self.cover) == 1:
            return 1
        single_edge_total = 0
        for e in self.cover:
            # A single-edge cover is indexed by the bare edge name.
            size = node_weights[e.name]
            if size != 0:
                single_edge_total = single_edge_total + (size * math.log(size))
        return self.getCoverWeight(node_weights) + single_edge_total

    def ReducedSz(self, node_weights):
        """1 if any child has a reduced size of at most 1, else the cover weight."""
        if any(c.ReducedSz(node_weights) <= 1 for c in self.children):
            return 1
        return self.getCoverWeight(node_weights)

    def ScanCost(self, node_weights):
        """0 without children, when the first child's reduced size is at most 1,
        or when the cover weight is 0; otherwise ``w * log(w)`` of the cover weight.
        """
        if len(self.children) == 0:
            return 0
        first_child = list(self.children)[0]
        if first_child.ReducedSz(node_weights) <= 1:
            return 0
        cover_weight = self.getCoverWeight(node_weights)
        if cover_weight == 0:
            return 0
        return cover_weight * math.log(cover_weight)

    def TotalWeight(self, node_weights):
        """Return ``(node weight, scan cost, accumulated subtree cost)``."""
        nodeWeight = self.NodeWeight(node_weights)
        subtree_costs = 0
        for c in self.children:
            c_node, c_scan, c_subtree = c.TotalWeight(node_weights)
            child_weight = c_node + c_subtree + c_scan
            reduced = c.ReducedSz(node_weights)
            if reduced == 0:
                subtree_costs = subtree_costs + child_weight
            else:
                subtree_costs = subtree_costs + child_weight + (reduced * math.log(reduced))
        return (nodeWeight, self.ScanCost(node_weights), subtree_costs)

    def removeMarkers(self):
        """Drop, in place, all vertices ``v`` with ``int(v) < 0`` from the bag,
        from every cover edge, and recursively from all children.
        """
        negatives = [v for v in self.bag.vertices if int(v) < 0]
        for v in negatives:
            self.bag.vertices.remove(v)

        for e in self.cover:
            negatives = [v for v in e.V if int(v) < 0]
            for v in negatives:
                e.V.remove(v)

        for c in self.children:
            c.removeMarkers()

    def addChild(self, child):
        self.children.append(child)

    def toStringCost(self, depth, node_weights):
        """Indented textual dump of the subtree including the cost figures."""
        node_weight, scan_cost, sj_costs = self.TotalWeight(node_weights)
        tabby = "\n " + "\t" * depth
        rendered_children = [
            child.toStringCost(depth + 1, node_weights) for child in self.children
        ]
        header = "".join([
            "Bag: ", str(self.bag),
            " Cover: ", str(self.cover),
            "Total Cost: ", str(node_weight + sj_costs),
            " NodeCost: ", str(node_weight),
            " ScanCost: ", str(scan_cost),
            "  SubTree Costs:", str(sj_costs),
        ])
        return header + tabby + tabby.join(rendered_children)

    def toString(self, depth):
        """Indented textual dump of the subtree (bags and covers only)."""
        tabby = "\n " + "\t" * depth
        rendered_children = [child.toString(depth + 1) for child in self.children]
        header = "Bag: " + str(self.bag) + " Cover: " + str(self.cover)
        return header + tabby + tabby.join(rendered_children)

    def __repr__(self):
        return self.toString(1)

class NodeEncoder(JSONEncoder):
    """JSON encoder that turns a decomposition node into nested dicts and lists."""

    def default(self, o):
        cover = []
        for e in o.cover:
            cover.append({'name': e.name, 'vertices': list(e.V)})
        children = []
        for c in o.children:
            children.append(self.default(c))
        return {'bag': list(o.bag.vertices), 'cover': cover, 'children': children}

@dataclass
class WeightedBasis:
    """A candidate basis together with its two weights."""
    weight: int
    weight_ideal: int
    basis: Any

    def __lt__(self, other):
        # Reversed on purpose: the instance with the larger weight compares
        # as "smaller", so heapq.nlargest hands back the lowest weights.
        return other.weight < self.weight


class CTDOpt(object):
    def __init__(self,h):
        """Set up the optimiser state for hypergraph ``h``.

        The root block has an empty head and the whole vertex set as tail;
        its weights start at ``sys.maxsize`` (meaning "no decomposition yet").

        ``node_weights`` and ``node_weights_ideal`` now start as empty dicts,
        so ``addBlock`` works even before ``add_weights`` /
        ``add_weights_ideal`` have been called (it used to raise
        AttributeError on the first trivial block).
        """
        self.H = h                   # hypergraph
        self.root_block = Block(VertSet(set()), set(), VertSet(h.V))
        self.blocks = set([self.root_block])
        self.satisfied_block = set() # indicating which blocks are satisfied
        self.head_to_blocks = dict() # mapping heads to blocks headed by them
        self.weights = dict() # maps block to weight
        self.weights[self.root_block] = sys.maxsize
        self.weights_ideal = dict()
        self.weights_ideal[self.root_block] = sys.maxsize
        self.node_weights = dict() # cover index -> weight, see add_weights
        self.node_weights_ideal = dict() # cover index -> weight, see add_weights_ideal
        self.sj_weights = dict()
        self.children = dict()
        self.top_children = dict()
        self.top_children[self.root_block] = []
        self.new_blocks = set()
        self.head_to_cover = dict() # cache the edge covers
        self.block_to_basis = dict() # mapping a satisfied block to its basis
        self.rootHead = None # cache the root head once found

    def addBlock(self,b):
        """Register block ``b``; a block that is already known is ignored.

        A block with an empty tail is trivially satisfied: it gets the weights
        stored for its cover index (1 if the index is unknown) and an empty
        child set. Any other block starts with weight ``sys.maxsize``.
        """
        assert type(b) is Block
        if b in self.blocks:
            return # don't add same block twice
        self.blocks.add(b)
        self.head_to_cover[b.head] = b.cover
        self.new_blocks.add(b)
        self.head_to_blocks.setdefault(b.head, []).append(b)
        self.top_children[b] = []

        if len(b.tail.vertices) != 0:
            self.weights[b] = sys.maxsize
            self.weights_ideal[b] = sys.maxsize
            return

        # empty tail: trivially satisfied
        self.satisfied_block.add(b)
        block_index = b.index()
        known = block_index in self.node_weights
        self.weights[b] = self.node_weights[block_index] if known else 1
        self.weights_ideal[b] = self.node_weights_ideal[block_index] if known else 1
        self.children[b] = set()

    def minimize_weights(self, topn):
        """Iterate block weights to a fixpoint and return the decompositions.

        In every round each non-trivial block is offered all bases reported by
        ``determine_bases`` for the blocks changed in the previous round.
        Whenever a basis is cheaper than the block's current weight, the
        block's weight, children and basis are replaced and the block is
        marked as changed. The loop ends when a round changes nothing.

        Args:
            topn: number of candidate bases kept per block.

        Returns:
            The result of ``construct_tds(topn)``, or ``None`` (after printing
            a message) if the root block still has weight ``sys.maxsize``.
        """
        # new_blocks = blocks that were updated in the last iteration -> continue until there are no more updates
        while self.new_blocks != set():
            new = set() # keep track of newly added blocks to stop when nothing new is added
            for b in self.blocks:
                if len(b.tail.vertices) == 0:
                    # skip trivial blocks
                    continue
                bases = self.determine_bases(b, self.new_blocks)
                # candidates collected so far for b (grows across rounds)
                all_children = self.top_children[b]
                for basis in bases:
                    # in-place sort of the basis list by reducedSz
                    basis.sort(key=self.reducedSz)

                    new_weight = self.basis_weight(b, basis)
                    new_weight_ideal = self.basis_weight_ideal(b, basis)
                    weight = self.weights[b]
                    all_children.append(WeightedBasis(new_weight, new_weight_ideal, basis))
                    if new_weight < weight:
                        # strictly better basis: adopt it and flag b as changed
                        self.weights[b] = new_weight
                        self.weights_ideal[b] = new_weight_ideal
                        self.children[b] = basis
                        self.block_to_basis[b] = basis
                        new.add(b)
                # WeightedBasis.__lt__ is reversed, so nlargest keeps the
                # topn candidates with the lowest weight
                self.top_children[b] = heapq.nlargest(topn, all_children)
            self.new_blocks = new
        if self.weights[self.root_block] == sys.maxsize:
            print("no decomposition found")
            return None
        else:
            decomps = self.construct_tds(topn)
            return decomps

    def construct_td(self):
        return self.to_node(self.root_block)
    
    def construct_tds(self, topn):
        return self.to_nodes(self.root_block, topn)

    def add_weights(self, node_costs):
        self.node_weights = node_costs

    
    def add_weights_ideal(self, node_costs):
        self.node_weights_ideal = node_costs
        
    def add_sj_weights(self, sj_weights):
        self.sj_weights = sj_weights

    def block_weight(self, block):
        cover = list(map(lambda e: e.name, block.cover))
        block_index = ",".join(sorted(cover))
        if block_index in self.node_weights_ideal:
            return self.node_weights_ideal[block_index]
        else:
            return 1
            
    def sj_weight(self, from_b, to_b):
        cover_from = [e.name for e in from_b.cover]
        idx_from = ",".join(sorted(cover_from))
        cover_to = [e.name for e in to_b.cover]
        idx_to = ",".join(sorted(cover_to))
        idx = idx_from + "-" + idx_to

        if idx in self.sj_weights:
            return self.sj_weights[idx]
        else:
            return 10000000



    def getCoverWeight(self,cover):
        nuCover = []
        for e in cover:
            nuCover.append(e.name)        
        cover_index = ",".join(sorted(nuCover))
        return self.node_weights[cover_index]
        

    def get_node_weight(self,cover):
        if len(cover) == 1:
            return 1
        else:
            J_u = self.getCoverWeight(cover)
            sumSingleEdgeWeights = 0
            for e in cover:
                subcover = [e.name]
                subcover_index = ",".join(sorted(subcover))
                if self.node_weights[subcover_index] != 0:
                    sumSingleEdgeWeights = sumSingleEdgeWeights + (self.node_weights[subcover_index] * math.log(self.node_weights[subcover_index]))
            return J_u + sumSingleEdgeWeights

    #Note that ReducedAttr is assumed to be 0 here
    def reducedSz(self,block):
        if len(block.tail.vertices) == 0:
            return -1
        
        basis = self.children[block]   
        oneBasisBlock = next(iter(basis))
        basis_head = oneBasisBlock.head
        cover = self.head_to_cover[basis_head]
        
        # check if child of block has child with ReduceSz of 0
        for c in basis: 
            if len(c.tail.vertices) == 0:
                continue # trivial blocks are not nodes, hence ignored
            else:
                if self.reducedSz(c) <= 1:
                    return 1 # not 0, to encourage smaller TDs
        return  self.getCoverWeight(cover)

    def scan_cost(self,block):
        if len(block.tail.vertices) == 0:
            return 0

        
        basis = self.children[block]   
        oneBasisBlock = next(iter(basis))
        basis_head = oneBasisBlock.head
        cover = self.head_to_cover[basis_head]
        
        
        basis = self.children[block]   

        if self.reducedSz(basis[0]) <= 1:
            return 0

        if self.getCoverWeight(cover) == 0:
            return 0
        else:
            return self.getCoverWeight(cover) * math.log(self.getCoverWeight(cover))
        
  
      
        
    def basis_weight(self, block, basis):
        """Cost of realising ``block`` with ``basis``.

        Returns 0 for a trivial block. Otherwise the result is the node weight
        of the cover of the basis head, plus a scan cost, plus for every
        non-trivial basis block its own (recursively computed) cost and
        ``r * log(r)`` of its reduced size ``r`` (skipped when ``r == 0``).

        NOTE(review): the scan cost in the return statement is evaluated for
        ``b``, the loop variable left over from the ``for`` loop (i.e. the
        last block of ``basis``), not for ``block``. ``scan_cost`` reads
        ``self.children[...]`` rather than the ``basis`` argument, so this may
        be deliberate - confirm before changing.
        """
        if len(block.tail.vertices) == 0: # base case
            return 0
        else:
            # get weight of node
            oneBasisBlock = next(iter(basis))
            head = oneBasisBlock.head # can safely assume that basis of non-trivial block is non-empty
            root_cover = self.head_to_cover[head]
            current_node_weight = self.get_node_weight(root_cover)
            sj_costs = 0
            for b in basis:
                if len(b.tail.vertices) == 0:
                    continue # trivial nodes do not contribute to weight
                else:
                    # basis currently stored for the child block
                    child_basis = self.children[b]

                    child_cost = self.basis_weight(b,child_basis) # we can assume that any block in a basis also has a basis (or is trivial)
                    if  self.reducedSz(b) == 0:
                        sj_costs = sj_costs + child_cost
                    else :
                        sj_costs = sj_costs + child_cost + self.reducedSz(b) * math.log( self.reducedSz(b) )

            # b is whatever the loop above left behind (see NOTE in the docstring)
            return current_node_weight + self.scan_cost(b) +  sj_costs
        
    def basis_weight_ideal(self, block, basis):
        basis_sum = sum(list(map(lambda b: 0 if len(b.tail.vertices) == 0 else self.weights_ideal[b], basis)))
        sj_costs = 0
        for b in basis:
            child_basis = self.children[b]
            if len(child_basis) > 0:
                child_block = next(iter(child_basis))
                sj_costs += self.sj_weight(child_block, b)
        return self.block_weight(next(iter(basis))) + sj_costs + basis_sum
        # return sj_costs + basis_sum

    # determine the bases of a block wrt. new blocks (one of the blocks has to be from new_blocks)
    # a basis is a set of blocks
    def determine_bases(self, b, new_blocks):
        """Collect all bases of block ``b`` that involve at least one block of ``new_blocks``.

        For every known head, the candidate basis is the list of blocks with
        that head which compare ``<`` to ``b`` (``b`` itself excluded). The
        candidate is accepted when all three conditions checked below hold.
        If at least one basis is found, ``b`` is added to
        ``self.satisfied_block``.

        Returns:
            A list of bases; each basis is a list of blocks.
        """
        bases = []
        for head in self.head_to_blocks:
            allBlocks = self.head_to_blocks[head]
            headed_blocks = [x for x in allBlocks if x < b and not (x.head == b.head and x.tail == b.tail)]

            # only heads that contribute a block changed in the last round
            # (this also skips heads without any candidate block)
            if set(headed_blocks).intersection(new_blocks) == set():
                continue

            # NOTE(review): this loop has no effect - the `continue` only
            # advances the inner loop. It looks like it was meant to skip
            # heads with an unweighted block; confirm.
            for ob in headed_blocks:
                if self.weights[ob] == sys.maxsize:
                    continue

            # 3. condition (for each component C_i', the block (B', C_i') is satisfied
            cond3 = True
            for ob in headed_blocks:
                if not ob in self.satisfied_block:
                    cond3 = False
            if cond3 == False:
                continue #3nd Condition violated (testing first for efficiency)

            # 1. condition (the tail of the block b is a subset of the union of
            # the tails and the head
            unionTails = set()
            # union of the tails' vertices
            for ob in headed_blocks:
                for v in ob.tail.vertices:
                    unionTails.add(v)
            # add the head's vertices
            for v in head.vertices:
                unionTails.add(v)
            if not b.tail.vertices.issubset(unionTails):
                continue # 1st Condition violated

            # 2. condition (each hyperedge partially contained in the tail of b has to be contained
            # in the union of the tails and the head)
            cond2 = True
            for e in self.H.E:
                if len(e.V.intersection(b.tail.vertices)) == 0:
                    continue # find other edge
                if not e.V.issubset(unionTails):
                    cond2 = False
                    break
            if cond2 == False:
                continue # 2nd Condition violated

            # basis found!
            basis = headed_blocks
            bases.append(basis)
        if bases != []:
            self.satisfied_block.add(b)
        return bases

    def hasBasis(self,b):
        basisFound = False
        basisWitness = None
        for B in self.head_to_blocks:
            allBlocks = self.head_to_blocks[B]
            blocks = [x for x in allBlocks if x < b]

            cond3 = True
            for ob in blocks:
                if not ob in self.satisfied_block:
                    cond3 = False
            if cond3 == False:
                continue #3nd Condition violated (testing first for efficiency)

            unionTails = set()
            for ob in blocks:
                for v in ob.tail.vertices:
                    unionTails.add(v)
            for v in B.vertices:
                unionTails.add(v)
            if not  b.tail.vertices.issubset(unionTails):
                continue # 1st Condition violated
            cond2 = True
            for e in self.H.E:
                if len(e.V.intersection(b.tail.vertices)) == 0:
                    continue # find other edge
                if not e.V.issubset(unionTails):
                    cond2 = False
                    break
            if cond2 == False:
                continue # 2nd Condition violated
            basisFound = True
            basisWitness = B
            # print("The basis of ", b , " is ", B)
            # print("The blocks headed by ", B)
            # for BB in blocks:
            #     print(str(BB)+"\n")

            break
        if basisFound == True:
            self.satisfied_block.add(b)
            self.block_to_basis[b] = basisWitness
            return True
        else:
            return False

    def rootHeadFound(self):
        for head in self.head_to_blocks:
            blocks = self.head_to_blocks[head]
            allSatisfied = True
            for b in blocks:
                if not b in self.satisfied_block:
                    allSatisfied = False
            if allSatisfied == True:
                # print("Root Head is ",head)
                self.rootHead = head
                return True
        return False


    def hasDecomp(self):
        while True:            
            changed = False
            for b in self.blocks:
                if b in self.satisfied_block:
                    continue # already marked as satisfied
                res = self.hasBasis(b)
                if res == True:
                    changed = True
                    #print("Found basis for the block ", b)
                if self.rootHeadFound():
                    # print("Found decomp!")
                    return True
            if changed == False:
                # print("Nothing has changed anymore, terminating")
                return False

    def to_node(self,block):
        if not(block in self.satisfied_block):
            # print(block, " is not satisfied")
            return None  # Nothing to return if block not satisfied
        if len(block.tail.vertices) == 0:
            return Node(block.head,self.head_to_cover[block.head],list()) # leaf node
        basis = self.block_to_basis[block]

        node_children = list()
        for block_child in self.children[block]:
            if len(block_child.tail.vertices) != 0:
                node_children.append(self.to_node(block_child))

        basis_head = basis[0].head
        return Node(basis_head,self.head_to_cover[basis_head],node_children)

    def to_nodes(self,block,topn):
        if not(block in self.satisfied_block):
            # print(block, " is not satisfied")
            return None  # Nothing to return if block not satisfied
        if len(block.tail.vertices) == 0:
            # print(block, " is trivial")
            return Node(block.head,self.head_to_cover[block.head],list()) # leaf node

        #print("top basis", self.children[block])
        #print("weight: ", self.weights[block])
        # print("top children: ", self.top_children[block])
        nodes = []
        for weighted_basis in self.top_children[block]:
            basis = weighted_basis.basis
            node_children = list()
            for block_child in weighted_basis.basis:
                if len(block_child.tail.vertices) != 0:
                    node_children.append(self.to_node(block_child))

            basis_head = list(basis)[0].head

            node_children.sort(key=lambda n : n.ReducedSz(self.node_weights))
            
            nodes.append(Node(basis_head,self.head_to_cover[basis_head],node_children, weighted_basis.weight, weighted_basis.weight_ideal))
        return nodes

    def getDecomp(self,block):
        if not(block in self.satisfied_block):
            # print(block, " is not satisfied")
            return None  # Nothing to return if block not satisfied
        if len(block.tail.vertices) == 0:
            # print(block, " is trivial")
            return Node(block.head,self.head_to_cover[block.head],list()) # leaf node
        basis = self.block_to_basis[block]
        allBlocks = self.head_to_blocks[basis]
        blocks = [x for x in allBlocks if x < block]

        # print("Child BLocks for block ", block)
        # for bs in allBlocks:
        #     print(bs)

        children = list()
        for bs in blocks: 
            children.append(self.getDecomp(bs))
        
        return Node(basis,self.head_to_cover[basis],children)


    def getDecompRoot(self):
        if self.rootHead == None:
            return None  ## can't find decomp of whole graph if no root head

        allBlocks = self.head_to_blocks[self.rootHead]
        # print("Blocks of RootHead")
        # for bs in allBlocks:
        #     print(bs)


        blocks = [x for x in allBlocks if len(x.tail.vertices) != 0]

        # print("Non-Trivial Blocks of RootHead")
        # for bs in blocks:
        #     print(bs)

        children = list()
        for bs in blocks:
            children.append(self.getDecomp(bs))

        return Node(self.rootHead,self.head_to_cover[self.rootHead],children)

In [None]:
import itertools

def all_choose_k(S, k):
    """Iterate over all combinations of 1 up to ``k`` elements of ``S``.

    Combinations are produced size by size (all singletons first, then all
    pairs, ...), each as a tuple. The result is a one-shot iterator.
    """
    from itertools import chain, combinations

    # The per-size iterators are created eagerly, exactly once each.
    per_size = [combinations(S, size) for size in range(1, k + 1)]
    return chain.from_iterable(per_size)
#
def all_lambdas(E, k):
    """Yield the empty set, then every tuple of 1..``k`` elements of ``E``.

    This is a generator: it can be iterated only once.
    """
    yield set()
    yield from all_choose_k(E, k)

def _edges_vertex_union(edges):
    """Return the union of the vertex sets ``e.V`` of ``edges`` (empty set if there are none)."""
    if len(edges) == 1:
        return edges[0].V
    if len(edges) > 1:
        return functools.reduce(lambda acc, verts: acc.union(verts), (e.V for e in edges))
    return set()


# (over approximates) the bags produced by the LogK algorithm
def computesoftk(h, k):
    """Compute candidate bags together with the edge sets that cover them.

    For every edge selection ``P`` produced by ``all_lambdas(h.E, k)``, the
    hypergraph is separated by the vertices of ``P``; every resulting
    component ``C`` is then intersected with the vertices of every edge
    selection ``L``.

    Args:
        h: hypergraph exposing ``E`` (edges with a vertex set ``V``) and
            ``separate(vertices, only_vertices=...)``.
        k: maximum number of edges per selection.

    Returns:
        A duplicate-free list of pairs ``(B, L)`` where ``B`` is the vertex
        set ``C.V & vertices(L)`` with more than one vertex and ``L`` is the
        covering edge selection.
    """
    # all_lambdas() is a generator. It is iterated by BOTH loops below, so it
    # has to be materialised once; sharing the live generator lets the first
    # inner loop drain it, which silently restricts the outer loop to the
    # empty separator. The vertex union of each selection is loop-invariant
    # and therefore computed only once here.
    lambdas = [(L, _edges_vertex_union(L)) for L in all_lambdas(h.E, k)]

    softk = list()
    for _, separator in lambdas:
        for C in h.separate(separator, only_vertices=False):
            for L, covered in lambdas:
                B = set.intersection(C.V, covered)
                # List membership (not a hash set): B is a plain, unhashable set.
                if len(B) > 1 and (B, L) not in softk:
                    softk.append((B, L))
    return softk
    
# computes the blocks of a bag by computing its components w.r.t. h
def bag_to_blocks(h,pair):
    """Turn a ``(bag, cover)`` pair into the list of blocks headed by the bag.

    One block is created per component returned by
    ``h.separate(bag, only_vertices=True)``; a final block with an empty tail
    (the trivial block) is always appended.
    """
    bag, cover = pair[0], pair[1]
    component_blocks = [
        Block(VertSet(bag), cover, VertSet(component))
        for component in h.separate(bag, only_vertices=True)
    ]
    trivial_block = Block(VertSet(bag), cover, VertSet(set()))  # adding trivial block too
    return component_blocks + [trivial_block]


def bag_to_blocks_constraint(h,constraint, pair):
    """Like ``bag_to_blocks`` but keep only the blocks accepted by ``constraint``.

    ``constraint`` is a predicate on a block, e.g. ``lambda b: b.connected()``.
    The result is a lazy ``filter`` iterator, not a list.
    """
    return filter(constraint, bag_to_blocks(h, pair))


def computesoftkBlocks(h, k):
    """Same as ``computesoftk``, but return the blocks of every bag directly.

    The blocks of all ``(bag, cover)`` pairs are concatenated into one flat
    list, in the order in which ``computesoftk`` produced the pairs.
    """
    return [
        block
        for bag_and_cover in computesoftk(h, k)
        for block in bag_to_blocks(h, bag_and_cover)
    ]


def computesoftkBlocksConstraint(h, k,constraint):
    """Same as ``computesoftkBlocks``, restricted to blocks accepted by ``constraint``.

    ``constraint`` is a predicate on a block; the surviving blocks of all
    ``(bag, cover)`` pairs are returned as one flat list.
    """
    return [
        block
        for bag_and_cover in computesoftk(h, k)
        for block in bag_to_blocks_constraint(h, constraint, bag_and_cover)
    ]

def get_best_blocks(blocks, node_to_cost):
    """Keep one block per group of equal blocks, using the cheapest cover.

    Args:
        blocks: iterable of blocks (hashable, with ``index()``, ``cover``,
            ``head`` and ``tail``).
        node_to_cost: mapping from ``block.index()`` to a cost; blocks whose
            index is missing get weight 1.

    Returns:
        A list of freshly built ``Block`` objects, one per distinct block, in
        order of first appearance. On equal weights the cover seen first wins.
    """
    cheapest = {}  # block -> (weight, cover)
    for candidate in blocks:
        weight = node_to_cost.get(candidate.index(), 1)
        known = cheapest.get(candidate)
        if known is None or weight < known[0]:
            cheapest[candidate] = (weight, candidate.cover)

    return [
        Block(block.head, entry[1], block.tail)
        for block, entry in cheapest.items()
    ]

def covers_intersect(cover1, cover2):
    """Return True if some edge of ``cover2`` shares a vertex with ``cover1``.

    Both covers are iterables of edges exposing their vertices as ``V``.
    """
    left_vertices = set().union(*(edge.V for edge in cover1))
    return any(not left_vertices.isdisjoint(edge.V) for edge in cover2)

def get_connected_semijoins(covers):
    """Return all ordered pairs of distinct covers that share at least one vertex.

    Pairs are listed in the order of ``itertools.permutations(covers, 2)``.
    """
    connected = []
    for left, right in itertools.permutations(covers, 2):
        if covers_intersect(left, right):
            connected.append((left, right))
    return connected

    
def get_connected_semijoinsAll(covers):
    """Return every ordered pair of covers taken from two different positions.

    Unlike ``get_connected_semijoins`` no connectivity filter is applied.
    """
    pool = tuple(covers)
    return [
        (left, right)
        for i, left in enumerate(pool)
        for j, right in enumerate(pool)
        if i != j
    ]
        

## Insert the file name of the compiled JAR file in the variable `REWRITE_JAR` below

In [None]:


# ===== USER CONFIGURATION: insert the JAR file name here =====================
# File name (or path) of the compiled Scala rewriting tool. It is handed to
# `java -jar` when a rewriter starts its helper process.
REWRITE_JAR = 'rewrite-assembly-0.1.0-SNAPSHOT.jar'
# ===== END USER CONFIGURATION ================================================



from timeit import default_timer as timer

import threading


class Rewriting:
    """Plain record bundling the outcome of one rewriting step.

    Stores the original query, the rewritten query, the reported features,
    the reported time and the statements that drop the created objects.
    """

    def __init__(self, original, rewritten, features, time, drop_statements):
        self.original, self.rewritten = original, rewritten
        self.features, self.time = features, time
        self.drop_statements = drop_statements



class QueryRewriter:
    """Cost-based rewriter: turns a cyclic SQL query into decomposition-guided rewritings.

    The rewriting itself is delegated to the Scala tool (``REWRITE_JAR``),
    reached through a py4j ``JavaGateway``. Queries are executed on Postgres
    through psycopg2.

    Attributes filled by ``rewrite`` and kept across calls:
        node_to_cost: cover index (sorted edge names joined by ",") mapped to
            the cardinality reported by ``determineNodeWeightsJSON``.
        node_to_cost_ideal: same keys, mapped to the "Plan Rows" value of the
            EXPLAIN plan reported by ``determineNodeWeightsJSONIdeal``.
    """

    # Statement timeout applied to every benchmark connection (milliseconds).
    STATEMENT_TIMEOUT_MS = '300000'

    def __init__(self, host, database, user, password, port=5432, start_process=True):
        """Store the connection settings and connect to the rewriting tool.

        Args:
            host, database, user, password, port: Postgres connection settings.
            start_process: when True, start ``java -jar REWRITE_JAR`` and wait
                for its first output line before opening the gateway; when
                False, a gateway server is expected to be running already.
        """
        self.host = host
        self.database = database
        self.user = user
        self.password = password
        self.port = port
        self.jdbcString = f'jdbc:postgresql://{self.host}:{self.port}/{self.database}'
        self.node_to_cost = dict()
        self.node_to_cost_ideal = dict()

        # Always defined, so close() also works when no process was started.
        self.rewrite_process = None
        if start_process:
            self.rewrite_process = subprocess.Popen(['java', '-jar', REWRITE_JAR], stdout=subprocess.PIPE)
            # Wait for the first line, which is printed after the py4j server is started
            self.rewrite_process.stdout.readline()

        self.gateway = JavaGateway()

        self.rewriter = self.gateway.entry_point
        self.rewriter.connect(self.jdbcString, self.database, self.user, self.password)

    def _connect_postgres(self):
        """Open a new psycopg2 connection using the stored settings (including the port)."""
        return psycopg2.connect(
            host=self.host,
            port=self.port,
            database=self.database,
            user=self.user,
            password=self.password,
        )

    @staticmethod
    def _report_failed_statement(conn, error):
        """Print ``error`` and end the current transaction so later statements can run."""
        print("here")
        print(error)
        print("start commit")
        conn.commit()
        print("end commit")

    def _rewrite_and_load_hypergraph(self, query):
        """Run the initial rewriting of ``query`` and load its hypergraph.

        Returns the hypergraph read from 'output/hypergraph.txt', or ``None``
        when the tool reports the query as acyclic (the rewritten query is
        printed in that case).
        """
        self.rewriter.connect(self.jdbcString, self.database, self.user, self.password)
        self.rewriter.rewrite(query)

        output = json.loads(Path('output/output.json').read_text())
        hg = HyperGraph.fromHyperbench('output/hypergraph.txt')

        # Kept as an explicit comparison: the flag comes straight from JSON.
        if output['acyclic'] == True:
            print("query is acyclic. done")
            for line in output["rewritten_query"]:
                print(line + ";")
            return None
        return hg

    def rewrite_check_soft_numbers(self, query, k = 2, topn = 1):
        """Count the distinct candidate bags of ``query``.

        Returns "acyclic" for an acyclic query, otherwise the tuple
        (number of distinct block heads, number of distinct heads of blocks
        satisfying ``connected()``, number of hypergraph edges). ``topn`` is
        accepted for signature compatibility and not used.
        """
        hg = self._rewrite_and_load_hypergraph(query)
        if hg is None:
            return "acyclic"

        blocksAll = computesoftkBlocks(hg, k)
        blocks = computesoftkBlocksConstraint(hg, k, lambda b: b.connected())

        all_heads = {b.head for b in blocksAll}
        connected_heads = {b.head for b in blocks}
        return (len(all_heads), len(connected_heads), len(hg.E))

    def rewrite(self, query, k = 2, topn = 1, quitAfterComputingTDs= False, Connected=False):
        """Compute the ``topn`` cheapest decompositions of ``query`` and their rewritings.

        Args:
            query: SQL text of the query.
            k: maximum number of edges per cover.
            topn: passed to ``CTDOpt.minimize_weights``.
            quitAfterComputingTDs: when True, return right after the
                decompositions were computed (and their computation time was
                printed) with three empty lists.
            Connected: when True, only blocks satisfying ``connected()`` are
                used as candidates.

        Returns:
            "acyclic" for an acyclic query, otherwise a tuple
            ``(rewritings, drops, tds)`` of parallel lists: the rewritten
            statements, the matching drop statements and the decomposition.
        """
        hg = self._rewrite_and_load_hypergraph(query)
        if hg is None:
            return "acyclic"

        print('hg: ' + str(hg.markComplete() ) )
        hg = hg.markComplete()
        ctd = CTDOpt(hg)

        if Connected:
            blocks = computesoftkBlocksConstraint(hg, k, lambda b: b.connected())
        else:
            blocks = computesoftkBlocks(hg, k)

        # keep only distinct covers (first cover seen per index wins)
        covers_dict = dict()
        for b in blocks:
            index = ",".join(sorted(e.name for e in b.cover))
            covers_dict.setdefault(index, b.cover)

        # Multi-edge covers plus one single-edge cover per hypergraph edge.
        candidate_covers = [[e.name for e in cover] for cover in covers_dict.values() if len(cover) > 1]
        # NOTE(review): this compares an edge object with lists of names, so the
        # filter presumably never removes anything - confirm the intent.
        single_edge_covers = [[e.name] for e in hg.E if e not in candidate_covers]
        candidate_covers = candidate_covers + single_edge_covers

        # using SQL to extract cardinalities of bags and relations
        node_cost_stats = json.loads(self.rewriter.determineNodeWeightsJSON(json.dumps(candidate_covers)))
        for (cover, extracted_cardinality) in node_cost_stats:
            index = ",".join(sorted(cover))
            self.node_to_cost[index] = int(extracted_cardinality)
        print("node costs: " + str(self.node_to_cost))

        # using the older 'idealised' costs, based on EXPLAIN queries
        node_explain_plans_ideal = json.loads(self.rewriter.determineNodeWeightsJSONIdeal(json.dumps(candidate_covers)))
        for (cover, explain) in node_explain_plans_ideal:
            index = ",".join(sorted(cover))
            plan = json.loads(explain)[0]
            self.node_to_cost_ideal[index] = plan['Plan']['Plan Rows']
        print("node costs ideal: " + str(self.node_to_cost_ideal))

        semijoins = get_connected_semijoins(covers_dict.values())
        # keep only edge names (no vertices)
        semijoins = [([e.name for e in cover1], [e.name for e in cover2]) for (cover1, cover2) in semijoins]

        print("Getting sj costs from Java")
        # extracting SemiJoin costs via EXPLAIN queries, used as part of the 'idealised' cost
        sj_explain_plans = json.loads(self.rewriter.determineSemijoinWeightsJSON(json.dumps(semijoins)))
        print("Received sj costs from Java")

        sj_to_cost = dict()
        for (c1, c2, explain) in sj_explain_plans:
            index1 = ",".join(sorted(c1))
            index2 = ",".join(sorted(c2))
            plan = json.loads(explain)[0]
            sj_cost = plan['Plan']['Total Cost']
            # The plan cost includes producing both inputs; subtract the node
            # costs and never go below 1.
            node_costs = self.node_to_cost_ideal.get(index1, 0) + self.node_to_cost_ideal.get(index2, 0)
            sj_to_cost[index1 + "-" + index2] = max(sj_cost - node_costs, 1)
        print("sj costs: " + str(sj_to_cost))

        ctd.add_weights(self.node_to_cost)
        ctd.add_weights_ideal(self.node_to_cost_ideal)
        ctd.add_sj_weights(sj_to_cost)
        blocks = get_best_blocks(blocks, self.node_to_cost)

        # The weights have to be added before the blocks because the join costs for trivially satisfied blocks
        # are set in addBlock
        for b in blocks:
            ctd.addBlock(b)

        start_time_TD_comp = timer()
        res = ctd.minimize_weights(topn)
        elapsed_time_TD_comp = timer() - start_time_TD_comp

        print("Time it took to compute TDs: ", elapsed_time_TD_comp,  "  len res set" , len(res))

        rewritings = []
        drops = []
        tds = []

        if quitAfterComputingTDs:
            return (rewritings, drops, tds)

        for td in res:
            if k > 1 and len(td.cover) == 1:
                continue #skip TD whose root has only a single edge cover

            td.removeMarkers()

            # The tool writes its result to the output/ files read below.
            self.rewriter.rewriteCyclicJSON(json.dumps(td, cls=NodeEncoder))
            output = json.loads(Path('output/output.json').read_text())
            rewritings.append(output["rewritten_query"])
            drop_output = json.loads(Path('output/drop.json').read_text())
            drops.append(drop_output["rewritten_query"])
            tds.append(td)

        return (rewritings, drops, tds)

    def run_rewriting(self, rewriting, drop_statements):
        """Execute the statements of ``rewriting`` and then ``drop_statements`` in reverse order.

        A failing rewriting statement (e.g. the statement timeout) is printed
        and ends the rewriting; the drop statements are still executed. The
        cursor and the connection are closed even if an exception escapes.
        """
        from contextlib import closing

        with closing(self._connect_postgres()) as conn, closing(conn.cursor()) as cur:
            cur.execute("SET statement_timeout = %s", (self.STATEMENT_TIMEOUT_MS,))  # timeout in milliseconds

            print( "entering try block ")

            try:
                for query in rewriting:
                    cur.execute(query)
            except Exception as e:  # deliberate best effort: report and continue with the drops
                self._report_failed_statement(conn, e)

            print("executing drop stuff")
            for drop in reversed(drop_statements):
                cur.execute(drop)
            print("fionished drop stuff")

    def run_query(self, query):
        """Execute ``query`` once and return the elapsed wall-clock time in seconds.

        A failing query (e.g. the statement timeout) is printed, not raised;
        the time spent until then (including the error handling) is returned.
        """
        from contextlib import closing

        with closing(self._connect_postgres()) as conn, closing(conn.cursor()) as cur:
            cur.execute("SET statement_timeout = %s", (self.STATEMENT_TIMEOUT_MS,))  # timeout in milliseconds

            start_time = timer()
            try:
                cur.execute(query)
            except Exception as e:  # deliberate best effort: report and return the time
                self._report_failed_statement(conn, e)
            elapsed_time = timer() - start_time

        return elapsed_time

    def close(self):
        """Stop the helper Java process, if this instance started one. Safe to call repeatedly."""
        process = self.rewrite_process
        if process is None:
            return
        process.kill()
        process.wait()  # reap the process before a new one is started
        if process.stdout is not None:
            process.stdout.close()
        self.rewrite_process = None



    

# TODO: use this one for the unoptimised version

class QueryRewriterUnOpt:
    """Rewriter variant without cost optimisation: every node and semijoin cost is 0.

    It offers the same methods as ``QueryRewriter`` but does not ask the
    rewriting tool for cardinalities or EXPLAIN costs; ``node_to_cost`` and
    ``node_to_cost_ideal`` map every candidate cover index to 0.
    """

    # Statement timeout applied to every benchmark connection (milliseconds).
    STATEMENT_TIMEOUT_MS = '300000'

    def __init__(self, host, database, user, password, port=5432, start_process=True):
        """Store the connection settings and connect to the rewriting tool.

        Args:
            host, database, user, password, port: Postgres connection settings.
            start_process: when True, start ``java -jar REWRITE_JAR`` and wait
                for its first output line before opening the gateway; when
                False, a gateway server is expected to be running already.
        """
        self.host = host
        self.database = database
        self.user = user
        self.password = password
        self.port = port
        self.jdbcString = f'jdbc:postgresql://{self.host}:{self.port}/{self.database}'
        self.node_to_cost = dict()
        self.node_to_cost_ideal = dict()

        # Always defined, so close() also works when no process was started.
        self.rewrite_process = None
        if start_process:
            self.rewrite_process = subprocess.Popen(['java', '-jar', REWRITE_JAR], stdout=subprocess.PIPE)
            # Wait for the first line, which is printed after the py4j server is started
            self.rewrite_process.stdout.readline()

        self.gateway = JavaGateway()

        self.rewriter = self.gateway.entry_point
        self.rewriter.connect(self.jdbcString, self.database, self.user, self.password)

    def _connect_postgres(self):
        """Open a new psycopg2 connection using the stored settings (including the port)."""
        return psycopg2.connect(
            host=self.host,
            port=self.port,
            database=self.database,
            user=self.user,
            password=self.password,
        )

    @staticmethod
    def _report_failed_statement(conn, error):
        """Print ``error`` and end the current transaction so later statements can run."""
        print("here")
        print(error)
        print("start commit")
        conn.commit()
        print("end commit")

    def _rewrite_and_load_hypergraph(self, query):
        """Run the initial rewriting of ``query`` and load its hypergraph.

        Returns the hypergraph read from 'output/hypergraph.txt', or ``None``
        when the tool reports the query as acyclic (the rewritten query is
        printed in that case).
        """
        self.rewriter.connect(self.jdbcString, self.database, self.user, self.password)
        self.rewriter.rewrite(query)

        output = json.loads(Path('output/output.json').read_text())
        hg = HyperGraph.fromHyperbench('output/hypergraph.txt')

        # Kept as an explicit comparison: the flag comes straight from JSON.
        if output['acyclic'] == True:
            print("query is acyclic. done")
            for line in output["rewritten_query"]:
                print(line + ";")
            return None
        return hg

    def rewrite_check_soft_numbers(self, query, k = 2, topn = 1):
        """Count the distinct candidate bags of ``query``.

        Returns "acyclic" for an acyclic query, otherwise the tuple
        (number of distinct block heads, number of distinct heads of blocks
        satisfying ``connected()``, number of hypergraph edges). ``topn`` is
        accepted for signature compatibility and not used.
        """
        hg = self._rewrite_and_load_hypergraph(query)
        if hg is None:
            return "acyclic"

        blocksAll = computesoftkBlocks(hg, k)
        blocks = computesoftkBlocksConstraint(hg, k, lambda b: b.connected())

        all_heads = {b.head for b in blocksAll}
        connected_heads = {b.head for b in blocks}
        return (len(all_heads), len(connected_heads), len(hg.E))

    def rewrite(self, query, k = 2, topn = 1, Connected=False, quitAfterComputingTDs=False):
        """Compute ``topn`` decompositions of ``query`` with all costs set to 0, and their rewritings.

        Args:
            query: SQL text of the query.
            k: maximum number of edges per cover.
            topn: passed to ``CTDOpt.minimize_weights``.
            Connected: when True, only blocks satisfying ``connected()`` are
                used as candidates.
            quitAfterComputingTDs: when True, print the time needed to compute
                the decompositions and return three empty lists (same keyword
                as ``QueryRewriter.rewrite``).

        Returns:
            "acyclic" for an acyclic query, otherwise a tuple
            ``(rewritings, drops, tds)`` of parallel lists: the rewritten
            statements, the matching drop statements and the decomposition.
        """
        hg = self._rewrite_and_load_hypergraph(query)
        if hg is None:
            return "acyclic"

        print('hg: ' + str(hg.markComplete() ) )
        hg = hg.markComplete()
        ctd = CTDOpt(hg)

        if Connected:
            blocks = computesoftkBlocksConstraint(hg, k, lambda b: b.connected())
        else:
            blocks = computesoftkBlocks(hg, k)

        # keep only distinct covers (first cover seen per index wins)
        covers_dict = dict()
        for b in blocks:
            index = ",".join(sorted(e.name for e in b.cover))
            covers_dict.setdefault(index, b.cover)

        # Multi-edge covers plus one single-edge cover per hypergraph edge.
        candidate_covers = [[e.name for e in cover] for cover in covers_dict.values() if len(cover) > 1]
        # NOTE(review): this compares an edge object with lists of names, so the
        # filter presumably never removes anything - confirm the intent.
        single_edge_covers = [[e.name] for e in hg.E if e not in candidate_covers]
        candidate_covers = candidate_covers + single_edge_covers

        # Unoptimised variant: no statistics are fetched, every cover costs 0.
        for cover in candidate_covers:
            index = ",".join(sorted(cover))
            self.node_to_cost[index] = 0
            self.node_to_cost_ideal[index] = 0

        # Every ordered pair of covers is a semijoin candidate, each with cost 0.
        sj_to_cost = dict()
        for (cover1, cover2) in get_connected_semijoinsAll(covers_dict.values()):
            index1 = ",".join(sorted(e.name for e in cover1))
            index2 = ",".join(sorted(e.name for e in cover2))
            sj_to_cost[index1 + "-" + index2] = 0

        ctd.add_weights(self.node_to_cost)
        ctd.add_weights_ideal(self.node_to_cost_ideal)
        ctd.add_sj_weights(sj_to_cost)
        blocks = get_best_blocks(blocks, self.node_to_cost)

        print("start to ad blocks")

        # The weights have to be added before the blocks because the join costs for trivially satisfied blocks
        # are set in addBlock
        for b in blocks:
            ctd.addBlock(b)

        print("added all blocks ", len(blocks))

        start_time_TD_comp = timer()
        res = ctd.minimize_weights(topn)
        elapsed_time_TD_comp = timer() - start_time_TD_comp

        rewritings = []
        drops = []
        tds = []

        if quitAfterComputingTDs:
            print("Time it took to compute TDs: ", elapsed_time_TD_comp,  "  len res set" , len(res))
            return (rewritings, drops, tds)

        for td in res:
            td.removeMarkers()

            # The tool writes its result to the output/ files read below.
            self.rewriter.rewriteCyclicJSON(json.dumps(td, cls=NodeEncoder))
            output = json.loads(Path('output/output.json').read_text())
            rewritings.append(output["rewritten_query"])
            drop_output = json.loads(Path('output/drop.json').read_text())
            drops.append(drop_output["rewritten_query"])
            tds.append(td)

        return (rewritings, drops, tds)

    def run_rewriting(self, rewriting, drop_statements):
        """Execute the statements of ``rewriting`` and then ``drop_statements`` in reverse order.

        A failing rewriting statement (e.g. the statement timeout) is printed
        and ends the rewriting; the drop statements are still executed. The
        cursor and the connection are closed even if an exception escapes.
        """
        from contextlib import closing

        with closing(self._connect_postgres()) as conn, closing(conn.cursor()) as cur:
            cur.execute("SET statement_timeout = %s", (self.STATEMENT_TIMEOUT_MS,))  # timeout in milliseconds

            print( "entering try block ")

            try:
                for query in rewriting:
                    cur.execute(query)
            except Exception as e:  # deliberate best effort: report and continue with the drops
                self._report_failed_statement(conn, e)

            print("executing drop stuff")
            for drop in reversed(drop_statements):
                cur.execute(drop)
            print("fionished drop stuff")

    def run_query(self, query):
        """Execute ``query`` once and return the elapsed wall-clock time in seconds.

        A failing query (e.g. the statement timeout) is printed, not raised;
        the time spent until then (including the error handling) is returned.
        """
        from contextlib import closing

        with closing(self._connect_postgres()) as conn, closing(conn.cursor()) as cur:
            cur.execute("SET statement_timeout = %s", (self.STATEMENT_TIMEOUT_MS,))  # timeout in milliseconds

            start_time = timer()
            try:
                cur.execute(query)
            except Exception as e:  # deliberate best effort: report and return the time
                self._report_failed_statement(conn, e)
            elapsed_time = timer() - start_time

        return elapsed_time

    def close(self):
        """Stop the helper Java process, if this instance started one. Safe to call repeatedly."""
        process = self.rewrite_process
        if process is None:
            return
        process.kill()
        process.wait()  # reap the process before a new one is started
        if process.stdout is not None:
            process.stdout.close()
        self.rewrite_process = None

In [None]:
import pandas as pd
import statistics
from datetime import datetime

def run_top_rewritings(rewriter, query, k = 2, topn = 1,name="", Connected=False):
    """Rewrite ``query``, time every rewriting and store the results as CSV.

    Each rewriting is executed three times through ``rewriter.run_rewriting``;
    the mean wall-clock time is reported.

    Args:
        rewriter: object offering ``rewrite``, ``run_rewriting`` and
            ``node_to_cost`` (``QueryRewriter`` or ``QueryRewriterUnOpt``).
        query: SQL text of the query.
        k, topn: forwarded to ``rewriter.rewrite``.
        name: infix of the CSV file name 'results<name><timestamp>.csv'.
        Connected: forwarded to ``rewriter.rewrite`` as keyword; restricts the
            candidates to connected blocks.

    Returns:
        A DataFrame with one row per decomposition (columns 'td_name',
        'runtime', 'cost', 'cost_ideal', 'decomp', 'rewriting'). It is empty
        when the rewriter reports the query as acyclic.
    """
    outcome = rewriter.rewrite(query, k, topn, Connected=Connected)
    if isinstance(outcome, str):
        # rewrite() returns the string "acyclic" instead of a tuple when there
        # is nothing to decompose; unpacking it would raise a ValueError.
        rewritings, drops, tds = [], [], []
    else:
        (rewritings, drops, tds) = outcome

    results = []
    for i, (rewriting, drop_statements, td) in enumerate(zip(rewritings, drops, tds)):
        print("running query", i)

        times = []
        for _ in range(3):
            start_time = timer()
            rewriter.run_rewriting(rewriting, drop_statements)
            times.append(timer() - start_time)

        print("Times list: ",times)

        rewriterJoin = ";\n".join(rewriting)
        results.append(["TD" + str(i),statistics.mean(times), td.weight,td.weight_ideal,td.toStringCost(1,rewriter.node_to_cost),rewriterJoin])

    df = pd.DataFrame(results, columns = ['td_name','runtime', 'cost', 'cost_ideal', 'decomp', 'rewriting' ])
    df.to_csv('results'+name+str(datetime.now().strftime("%d%m%Y%H%M%S"))+'.csv')
    return df


def run_top_rewritings_soft_nums(rewriter, query, k = 2, topn = 1,name=""):
    """Return what ``rewriter.rewrite_check_soft_numbers(query, k, topn)`` reports.

    ``name`` is accepted for call compatibility with the other runners and
    is not used.
    """
    soft_numbers = rewriter.rewrite_check_soft_numbers(query, k, topn)
    return soft_numbers


def run_top_rewritings_check_TD_comp_time(rewriter, query, k = 2, topn = 1,name=""):
    """Run ``rewriter.rewrite`` with ``quitAfterComputingTDs=True`` and discard its result.

    ``name`` is accepted for call compatibility with the other runners and
    is not used. Returns ``None``.
    """
    stop_early = True
    rewriter.rewrite(query, k, topn, quitAfterComputingTDs = stop_early)

        

The tests for the 6 queries begin below. For each query, there are two cells: 
* The first cell runs the experiment. There are a number of options to select the kind of experiment to run
* The second cell visualises the results, plotting the runtime of each TD against its cost and the baseline

## Q_TPC-DS

In [None]:
# CELL TPC-DS Test

from IPython.display import display, HTML

if 'rewriter' in locals():
    rewriter.close()


%reload_ext line_profiler

query = """
SELECT min(ws_bill_customer_sk)
FROM   web_sales, 
       customer, 
       customer_address,
       catalog_sales,
       warehouse
WHERE  ws_bill_customer_sk = c_customer_sk 
       AND ca_address_sk =  c_current_addr_sk 
       AND c_current_addr_sk = cs_bill_addr_sk
       AND cs_warehouse_sk = w_warehouse_sk
       AND  w_warehouse_sq_ft = ws_quantity

"""


# --- Experiment configuration ------------------------------------------------
# noOpt = True  -> QueryRewriterUnOpt (no cost optimisation)
# noOpt = False -> QueryRewriter (cost-based)
noOpt = True

if noOpt:
    rewriter = QueryRewriterUnOpt('localhost', 'tpcds', 'postgres', 'postgres')
else:
    rewriter = QueryRewriter('localhost', 'tpcds', 'postgres', 'postgres')

# size_check:     only report the number of candidate bags.
# Connected:      restrict the candidates to connected blocks.
# OnlyComputeTDs: only measure the time needed to compute the decompositions.
size_check = False
Connected = True
OnlyComputeTDs = False

if size_check:
    sizeAll, sizeConnected, hgSize = run_top_rewritings_soft_nums(rewriter, query, k = 2, topn = 10, name="TPCDS-sortedTDs-filteredRoot")

    print("HG size: ", hgSize)
    print("Size All:", sizeAll)
    print("Size Connected:", sizeConnected)
elif OnlyComputeTDs:
    run_top_rewritings_check_TD_comp_time(rewriter, query, k = 2, topn = 10, name="TPCDS-sortedTDs-filteredRoot")
else:
    start_time = timer()
    # Connected is passed by keyword (a positional argument may not follow
    # keyword arguments). The result is named df_tpcds because the plotting
    # cell for this query reads df_tpcds.
    df_tpcds = run_top_rewritings(rewriter, query, k = 2, topn = 10, name="TPCDS-unOpt-connected", Connected=Connected)
    run_time = timer() - start_time
    print("run time Total: ", run_time)

    print("Comparison to baseline")
    time = rewriter.run_query(query)
    print("Run time: ", time)

rewriter.close()



In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np



# Left panel: scatter plot of runtime against cost, one point per decomposition.
fig = plt.figure(figsize=(8, 5))
ax1 = fig.add_subplot(1, 2, 1)
ax1.scatter(x=df_tpcds['cost'], y=df_tpcds['runtime'])

# Right panel: the same two columns rendered as a table.
df_tpcds_proj = df_tpcds.loc[:, ['cost', 'runtime']]
bbox = [0, 0, 1.5, 1]
font_size = 12

ax2 = fig.add_subplot(1, 2, 2)
ax2.axis('off')
mpl_table = ax2.table(
    cellText=df_tpcds_proj.values,
    rowLabels=df_tpcds_proj.index,
    bbox=bbox,
    colLabels=df_tpcds_proj.columns,
)
mpl_table.auto_set_font_size(False)
mpl_table.set_fontsize(font_size)

# CSV dump of both cost measures, the runtime and the baseline runtime
# (`time` is set by the experiment cell of this query).
df_tpcds_proj2 = df_tpcds.loc[:, ['cost', 'cost_ideal', 'runtime']]
df_tpcds_proj2['baseline'] = time
print(df_tpcds_proj2.to_csv(index=False, float_format='%.2f'))


## Q_HTO

In [None]:
# CELL HETIO Test


if 'rewriter' in locals():
    rewriter.close()

%reload_ext line_profiler

query = """
select min(hetio45173_0.s)
from   hetio45173 hetio45173_0, hetio45173 hetio45173_1, 
       hetio45160 hetio45160_2, hetio45160 hetio45160_3, 
       hetio45160 hetio45160_4, hetio45159 hetio45159_5, 
       hetio45159 hetio45159_6 
where  hetio45173_0.s = hetio45173_1.s and hetio45173_0.d = hetio45160_2.s and 
       hetio45173_1.d = hetio45160_3.s and hetio45160_2.d = hetio45160_3.d and 
       hetio45160_3.d = hetio45160_4.s and hetio45160_4.s = hetio45159_5.s and 
       hetio45160_4.d = hetio45159_6.s and hetio45159_5.d = hetio45159_6.d

"""

# --- Experiment configuration ------------------------------------------------
# noOpt = True  -> QueryRewriterUnOpt (no cost optimisation)
# noOpt = False -> QueryRewriter (cost-based)
noOpt = True

if noOpt:
    rewriter = QueryRewriterUnOpt('localhost', 'hetio', 'postgres', 'postgres')
else:
    rewriter = QueryRewriter('localhost', 'hetio', 'postgres', 'postgres')

# size_check:     only report the number of candidate bags.
# Connected:      restrict the candidates to connected blocks.
# OnlyComputeTDs: only measure the time needed to compute the decompositions.
size_check = False
Connected = True
OnlyComputeTDs = False

if size_check:
    sizeAll, sizeConnected, hgSize = run_top_rewritings_soft_nums(rewriter, query, k = 2, topn = 10, name="HETIO-sortedTDs-filteredRoot")

    print("HG size: ", hgSize)
    print("Size All:", sizeAll)
    print("Size Connected:", sizeConnected)
elif OnlyComputeTDs:
    run_top_rewritings_check_TD_comp_time(rewriter, query, k = 2, topn = 10, name="HETIO-sortedTDs-filteredRoot")
else:
    start_time = timer()
    # Connected is passed by keyword: a positional argument may not follow
    # keyword arguments.
    df_hetio1 = run_top_rewritings(rewriter, query, k = 2, topn = 10, name="HETIO-unOpt-connected", Connected=Connected)
    run_time = timer() - start_time
    print("run time Total: ", run_time)

    print("Comparison to baseline")
    time = rewriter.run_query(query)
    print("Run time: ", time)

rewriter.close()


In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np



# Left panel: scatter plot of runtime against cost, one point per decomposition.
fig = plt.figure(figsize=(8, 5))
ax1 = fig.add_subplot(1, 2, 1)
ax1.scatter(x=df_hetio1['cost'], y=df_hetio1['runtime'])

# Right panel: the same two columns rendered as a table.
df_hetio1_proj = df_hetio1.loc[:, ['cost', 'runtime']]
bbox = [0, 0, 1.5, 1]
font_size = 12

ax2 = fig.add_subplot(1, 2, 2)
ax2.axis('off')
mpl_table = ax2.table(
    cellText=df_hetio1_proj.values,
    rowLabels=df_hetio1_proj.index,
    bbox=bbox,
    colLabels=df_hetio1_proj.columns,
)
mpl_table.auto_set_font_size(False)
mpl_table.set_fontsize(font_size)

# CSV dump of both cost measures and the runtime. The baseline column is
# disabled in this cell; uncomment the next line to include it.
df_hetio1_proj2 = df_hetio1.loc[:, ['cost', 'cost_ideal', 'runtime']]
# df_hetio1_proj2['baseline']=time
print(df_hetio1_proj2.to_csv(index=False, float_format='%.2f'))


## Q2_HTO

In [None]:
# CELL HETIO2 Test


if 'rewriter' in locals():
    rewriter.close()

%reload_ext line_profiler

query = """
select      max(hetio45160.d) 
from        hetio45173 hetio45173_0, hetio45173 hetio45173_1, hetio45173 hetio45173_2, 
            hetio45173 hetio45173_3, hetio45160, hetio45176 hetio45176_5, hetio45176 
            hetio45176_6 
where       hetio45173_0.s = hetio45173_1.s and hetio45173_0.d = hetio45173_2.s and 
            hetio45173_1.d = hetio45173_3.s and hetio45173_2.d = hetio45173_3.d and 
            hetio45173_3.d = hetio45160.s and hetio45160.s = hetio45176_5.s and 
            hetio45160.d = hetio45176_6.s and hetio45176_5.d = hetio45176_6.d
"""



# --- Experiment configuration ------------------------------------------------
# noOpt = True  -> QueryRewriterUnOpt (no cost optimisation)
# noOpt = False -> QueryRewriter (cost-based)
noOpt = True

if noOpt:
    rewriter = QueryRewriterUnOpt('localhost', 'hetio', 'postgres', 'postgres')
else:
    rewriter = QueryRewriter('localhost', 'hetio', 'postgres', 'postgres')

# size_check:     only report the number of candidate bags.
# Connected:      restrict the candidates to connected blocks.
# OnlyComputeTDs: only measure the time needed to compute the decompositions.
size_check = False
Connected = True
OnlyComputeTDs = False

if size_check:
    sizeAll, sizeConnected, hgSize = run_top_rewritings_soft_nums(rewriter, query, k = 2, topn = 10, name="HETIO2-sortedTDs-fiterRoot")

    print("HG size: ", hgSize)
    print("Size All:", sizeAll)
    print("Size Connected:", sizeConnected)
elif OnlyComputeTDs:
    run_top_rewritings_check_TD_comp_time(rewriter, query, k = 2, topn = 10, name="HETIO2-sortedTDs-fiterRoot")
else:
    start_time = timer()
    # Connected is passed by keyword: a positional argument may not follow
    # keyword arguments.
    df_hetio2 = run_top_rewritings(rewriter, query, k = 2, topn = 10, name="HETIO2-unOpt-connected", Connected=Connected)
    run_time = timer() - start_time
    print("run time Total: ", run_time)

    print("Comparison to baseline")
    time = rewriter.run_query(query)
    print("Run time: ", time)

rewriter.close()




In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Visualise the Q2_HTO results held in df_hetio2: a cost/runtime scatter plot
# on the left and the same two columns rendered as a table on the right.
fig = plt.figure(figsize=(8, 5))

ax1 = fig.add_subplot(1, 2, 1)
ax1.scatter(df_hetio2['cost'], df_hetio2['runtime'])

df_hetio2_proj = df_hetio2.loc[:, ['cost', 'runtime']]

font_size = 12
bbox = [0, 0, 1.5, 1]

ax2 = fig.add_subplot(1, 2, 2)
ax2.set_axis_off()  # the right-hand axes only hosts the table
mpl_table = ax2.table(
    cellText=df_hetio2_proj.values,
    rowLabels=df_hetio2_proj.index,
    colLabels=df_hetio2_proj.columns,
    bbox=bbox,
)
mpl_table.auto_set_font_size(False)
mpl_table.set_fontsize(font_size)

# CSV export: cost columns and runtime, plus the baseline measurement that the
# preceding cell stored in `time`, repeated on every row.
df_hetio2_proj2 = df_hetio2.loc[:, ['cost', 'cost_ideal', 'runtime']].assign(baseline=time)

print(df_hetio2_proj2.to_csv(index=False, float_format='%.2f'))


## Q3_HTO

In [None]:
# CELL hetio3 Test


if 'rewriter' in locals():
    rewriter.close()

%reload_ext line_profiler

query = """

select  min(hetio45173_2.d)
from    hetio45173 hetio45173_0, hetio45173 hetio45173_1, hetio45173 
        hetio45173_2, hetio45173 hetio45173_3 
where   hetio45173_0.s = hetio45173_1.s and hetio45173_0.d = hetio45173_2.s 
        and hetio45173_1.d = hetio45173_3.d and hetio45173_2.d = hetio45173_3.s

"""

noOpt = True


if noOpt:
    rewriter = QueryRewriterUnOpt('localhost', 'hetio', 'postgres', 'postgres')
else:        
    rewriter = QueryRewriter('localhost', 'hetio', 'postgres', 'postgres')
    


size_check = False
Connected = True
OnlyComputeTDs = False



if size_check:
    sizeAll, sizeConnected,hgSize = run_top_rewritings_soft_nums(rewriter, query, k = 2, topn = 10,name="HETIO3-sortedTDs-fiterRoot")

    print("HG size: ", hgSize)
    print("Size All:", sizeAll)
    print("Size Connected:", sizeConnected)
    
else:
    if OnlyComputeTDs:
        run_top_rewritings_check_TD_comp_time(rewriter, query, k = 2, topn = 10,name="HETIO3-sortedTDs-fiterRoot")
    else:   
        start_time = timer()
        df_hetio3 = run_top_rewritings(rewriter, query, k = 2, topn = 10,name="HETIO3-unOpt-connected",Connected)
        run_time = timer() - start_time
        print("run time Total: ", run_time)

                
        print("Comparison to baseline")
        time = rewriter.run_query(query)
        print("Run time: ", time)



rewriter.close()

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Visualise the Q3_HTO results held in df_hetio3: a cost/runtime scatter plot
# on the left and the same two columns rendered as a table on the right.
fig = plt.figure(figsize=(8, 5))

ax1 = fig.add_subplot(1, 2, 1)
ax1.scatter(df_hetio3['cost'], df_hetio3['runtime'])

df_hetio3_proj = df_hetio3.loc[:, ['cost', 'runtime']]

font_size = 12
bbox = [0, 0, 1.5, 1]

ax2 = fig.add_subplot(1, 2, 2)
ax2.set_axis_off()  # the right-hand axes only hosts the table
mpl_table = ax2.table(
    cellText=df_hetio3_proj.values,
    rowLabels=df_hetio3_proj.index,
    colLabels=df_hetio3_proj.columns,
    bbox=bbox,
)
mpl_table.auto_set_font_size(False)
mpl_table.set_fontsize(font_size)

# CSV export: cost columns and runtime, plus the baseline measurement that the
# preceding cell stored in `time`, repeated on every row.
df_hetio3_proj2 = df_hetio3.loc[:, ['cost', 'cost_ideal', 'runtime']].assign(baseline=time)

print(df_hetio3_proj2.to_csv(index=False, float_format='%.2f'))


## Q4_HTO

In [None]:
# CELL hetio4 Test


if 'rewriter' in locals():
    rewriter.close()

%reload_ext line_profiler

query = """

select  min(hetio45160_0.s) 
from    hetio45160 hetio45160_0, hetio45160 hetio45160_1, 
        hetio45177, hetio45160 hetio45160_3, hetio45159 
        hetio45159_4, hetio45159 hetio45159_5 
where   hetio45160_0.s = hetio45160_1.s and hetio45160_0.d = hetio45177.s 
        and hetio45160_1.d = hetio45177.d and hetio45177.d = hetio45160_3.s 
        and hetio45160_3.s = hetio45159_4.s and hetio45160_3.d = hetio45159_5.s 
        and hetio45159_4.d = hetio45159_5.d
"""
noOpt = True


if noOpt:
    rewriter = QueryRewriterUnOpt('localhost', 'hetio', 'postgres', 'postgres')
else:        
    rewriter = QueryRewriter('localhost', 'hetio', 'postgres', 'postgres')
    


size_check = False
Connected = True
OnlyComputeTDs = False


if size_check:
    sizeAll, sizeConnected,hgSize = run_top_rewritings_soft_nums(rewriter, query, k = 2, topn = 10,name="HETIO4-sortedTDs-fiterRoot")

    print("HG size: ", hgSize)
    print("Size All:", sizeAll)
    print("Size Connected:", sizeConnected)
    
else:
    if OnlyComputeTDs:
        run_top_rewritings_check_TD_comp_time(rewriter, query, k = 2, topn = 10,name="HETIO4-sortedTDs-fiterRoot")
    else:   
        start_time = timer()
        df_hetio4 = run_top_rewritings(rewriter, query, k = 2, topn = 10,name="HETIO4-unOpt-connected", Connected)
        run_time = timer() - start_time
        print("run time Total: ", run_time)

                
        print("Comparison to baseline")
        time = rewriter.run_query(query)
        print("Run time: ", time)



rewriter.close()

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Visualise the Q4_HTO results held in df_hetio4: a cost/runtime scatter plot
# on the left and the same two columns rendered as a table on the right.
fig = plt.figure(figsize=(8, 5))

ax1 = fig.add_subplot(1, 2, 1)
ax1.scatter(df_hetio4['cost'], df_hetio4['runtime'])

df_hetio4_proj = df_hetio4.loc[:, ['cost', 'runtime']]

font_size = 12
bbox = [0, 0, 1.5, 1]

ax2 = fig.add_subplot(1, 2, 2)
ax2.set_axis_off()  # the right-hand axes only hosts the table
mpl_table = ax2.table(
    cellText=df_hetio4_proj.values,
    rowLabels=df_hetio4_proj.index,
    colLabels=df_hetio4_proj.columns,
    bbox=bbox,
)
mpl_table.auto_set_font_size(False)
mpl_table.set_fontsize(font_size)

# CSV export: cost columns and runtime, plus the baseline measurement that the
# preceding cell stored in `time`, repeated on every row.
df_hetio4_proj2 = df_hetio4.loc[:, ['cost', 'cost_ideal', 'runtime']].assign(baseline=time)

print(df_hetio4_proj2.to_csv(index=False, float_format='%.2f'))


## Q_LSQB

In [None]:
# CELL LSQB Test
#

if 'rewriter' in locals():
    rewriter.close()


%reload_ext line_profiler

import time

query = """
SELECT MIN(pkp1.Person1Id)
FROM City AS CityA
JOIN City AS CityB
  ON CityB.isPartOf_CountryId = CityA.isPartOf_CountryId
JOIN City AS CityC
  ON CityC.isPartOf_CountryId = CityA.isPartOf_CountryId
JOIN Person AS PersonA
  ON PersonA.isLocatedIn_CityId = CityA.CityId
JOIN Person AS PersonB
  ON PersonB.isLocatedIn_CityId = CityB.CityId
JOIN Person_knows_Person AS pkp1
  ON pkp1.Person1Id = PersonA.PersonId
 AND pkp1.Person2Id = PersonB.PersonId
"""

noOpt = True


if noOpt:
    #rewriter = QueryRewriter('postgres', 'tpch', 'tpch', 'tpch', start_process=False)
    rewriter = QueryRewriterUnOpt('localhost', 'lsqb', 'postgres', 'postgres')
else:        
    #rewriter = QueryRewriter('postgres', 'tpch', 'tpch', 'tpch', start_process=False)
    rewriter = QueryRewriter('localhost', 'lsqb', 'postgres', 'postgres')
    


size_check = False
Connected = True
OnlyComputeTDs = False


if size_check:
    sizeAll, sizeConnected,hgSize = run_top_rewritings_soft_nums(rewriter, query, k = 3, topn = 10,name="LSQB-sortedTDs-fiterRoot")

    print("HG size: ", hgSize)
    print("Size All:", sizeAll)
    print("Size Connected:", sizeConnected)
    
else:
    if OnlyComputeTDs:
        run_top_rewritings_check_TD_comp_time(rewriter, query, k = 3, topn = 10,name="LSQB-sortedTDs-fiterRoot")
    else:   
        start_time = timer()
        df_lsqb = run_top_rewritings(rewriter, query, k = 3, topn = 10,name="LSQB-unOpt-connected", Connected)
        run_time = timer() - start_time
        print("run time Total: ", run_time)

                
        print("Comparison to baseline")
        time = rewriter.run_query(query)
        print("Run time: ", time)



rewriter.close()

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Visualise the Q_LSQB results held in df_lsqb: a cost/runtime scatter plot
# on the left and the same two columns rendered as a table on the right.
fig = plt.figure(figsize=(8, 5))

ax1 = fig.add_subplot(1, 2, 1)
ax1.scatter(df_lsqb['cost'], df_lsqb['runtime'])

df_lsqb_proj = df_lsqb.loc[:, ['cost', 'runtime']]

font_size = 12
bbox = [0, 0, 1.5, 1]

ax2 = fig.add_subplot(1, 2, 2)
ax2.set_axis_off()  # the right-hand axes only hosts the table
mpl_table = ax2.table(
    cellText=df_lsqb_proj.values,
    rowLabels=df_lsqb_proj.index,
    colLabels=df_lsqb_proj.columns,
    bbox=bbox,
)
mpl_table.auto_set_font_size(False)
mpl_table.set_fontsize(font_size)

# CSV export: cost columns and runtime, plus the baseline measurement that the
# preceding cell stored in `time`, repeated on every row.
df_lsqb_proj2 = df_lsqb.loc[:, ['cost', 'cost_ideal', 'runtime']].assign(baseline=time)

print(df_lsqb_proj2.to_csv(index=False, float_format='%.2f'))
