# 3 - Overcoming SOTA Performance On IMDB With JOB-Light
*By Marcus Schwarting and Andronicus Samsundar Rajasukumar*

In this notebook, we will:
- Introduce the Kipf et al. model that we wish to improve upon
- Show different implementations of featurization routines and discuss their pros and cons
- Discuss changes to the Kipf et al. implementation that yielded overall improvements in accuracy and training time

The benchmark we aim to beat is the Q-error reported in the literature (Kipf et al.) on the JOB-light test query set over the IMDB dataset:

| Q-Error Metric | Value |
| ---- | ---- |
|Median | 3.82|
|90th Percentile| 78.4|
|95th Percentile|362|
|Max|1110|
|Mean|57.9|
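Each figure in this table is a Q-error: for a predicted cardinality $\hat{c}$ and a true cardinality $c$, the Q-error is $\max(\hat{c}/c,\ c/\hat{c})$, so a perfect estimate scores 1 and over- and under-estimates by the same factor are penalized equally. The following is a minimal vectorized NumPy sketch of that summary, assuming strictly positive cardinalities; the function name `qerror_summary` and the toy numbers are hypothetical and not part of the Kipf et al. code.

```python
import numpy as np

def qerror_summary(preds, labels):
    """Summarize Q-errors, assuming strictly positive predicted and true cardinalities."""
    preds = np.asarray(preds, dtype=float)
    labels = np.asarray(labels, dtype=float)
    # Symmetric ratio: being 2x too high or 2x too low both give a Q-error of 2
    qerror = np.maximum(preds / labels, labels / preds)
    return {
        "median": float(np.median(qerror)),
        "90th": float(np.percentile(qerror, 90)),
        "95th": float(np.percentile(qerror, 95)),
        "max": float(np.max(qerror)),
        "mean": float(np.mean(qerror)),
    }

# Hypothetical estimates: exact, 2x too high, 4x too low
stats = qerror_summary([100, 200, 25], [100, 100, 100])
print(stats["median"], stats["max"])  # 2.0 4.0
```

Because the mean is dominated by a few very large ratios, the median and upper percentiles are reported alongside it.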

```python
# Modified version of the Kipf et al. code (originally from https://github.com/andreaskipf/learnedcardinalities)
import time
import os

import numpy as np
import torch
from torch.utils.data import DataLoader

# mscn.util provides the encoding and normalization helpers used below
# (encode_samples, encode_data, normalize_data, normalize_labels, unnormalize_labels)
from mscn.util import *
from mscn.data import get_train_datasets, load_data, make_dataset
from mscn.model import SetConv
```


### Introducing the Kipf et al. MSCN Model
Kipf et al. achieve the benchmark performance above with a multi-set convolutional network (MSCN). We have reimplemented their method with some changes that marginally improve on the state of the art. Below, we reuse some of their code infrastructure and point out the important changes where they apply.

```python
def unnormalize_torch(vals, min_val, max_val):
    # Invert the log + min-max label normalization; min_val/max_val are read from "imdb_max_min.csv"
    vals = (vals * (max_val - min_val)) + min_val
    return torch.exp(vals)


def qerror_loss(preds, targets, min_val, max_val):
    # Mean Q-error over the batch, computed on unnormalized cardinalities (can return MAE instead if desired).
    qerror = []
    preds = unnormalize_torch(preds, min_val, max_val)
    targets = unnormalize_torch(targets, min_val, max_val)

    for i in range(len(targets)):
        if (preds[i] > targets[i]).cpu().data.numpy()[0]:
            qerror.append(preds[i] / targets[i])
        else:
            qerror.append(targets[i] / preds[i])
    return torch.mean(torch.cat(qerror))


def predict(model, data_loader):
    # The workhorse: runs the trained model over a data loader and collects its predictions.
    preds = []
    t_total = 0.

    model.eval()
    with torch.no_grad():  # inference only, so no gradient bookkeeping is needed
        for batch_idx, data_batch in enumerate(data_loader):
            samples, predicates, joins, targets, sample_masks, predicate_masks, join_masks = data_batch
            t = time.time()
            outputs = model(samples, predicates, joins, sample_masks, predicate_masks, join_masks)
            t_total += time.time() - t

            for i in range(outputs.data.shape[0]):
                preds.append(outputs.data[i])

    return preds, t_total


def print_qerror(preds_unnorm, labels_unnorm):
    # Q-error = max(pred / true, true / pred); print its summary statistics
    qerror = []
    for i in range(len(preds_unnorm)):
        if preds_unnorm[i] > float(labels_unnorm[i]):
            qerror.append(preds_unnorm[i] / float(labels_unnorm[i]))
        else:
            qerror.append(float(labels_unnorm[i]) / float(preds_unnorm[i]))

    print(f"Median: {np.median(qerror)}")
    print(f"90th percentile: {np.percentile(qerror, 90)}")
    print(f"95th percentile: {np.percentile(qerror, 95)}")
    print(f"99th percentile: {np.percentile(qerror, 99)}")
    print(f"Max: {np.max(qerror)}")
    print(f"Mean: {np.mean(qerror)}")
```
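`unnormalize_torch` (like `unnormalize_labels` from `mscn.util`) inverts the label normalization used for training: cardinalities are log-transformed and then min-max scaled with the minimum and maximum log-cardinality, so `qerror_loss` has to undo both steps before taking ratios. The following NumPy sketch shows the round trip under that assumption; `normalize_log_minmax` and `unnormalize_log_minmax` are hypothetical names, not functions from the MSCN code.

```python
import numpy as np

def normalize_log_minmax(labels, min_val=None, max_val=None):
    """Log-transform cardinalities, then min-max scale them.

    Pass both bounds (e.g. from the training set) or neither, in which case
    they are taken from these labels.
    """
    log_labels = np.log(np.asarray(labels, dtype=float))
    if min_val is None:
        min_val, max_val = log_labels.min(), log_labels.max()
    return (log_labels - min_val) / (max_val - min_val), min_val, max_val

def unnormalize_log_minmax(vals, min_val, max_val):
    """Undo the min-max scaling, then exponentiate (the same steps as unnormalize_torch)."""
    return np.exp(np.asarray(vals, dtype=float) * (max_val - min_val) + min_val)

train_labels = [10, 1000, 100000]
norm, lo, hi = normalize_log_minmax(train_labels)
print(norm)                                  # approximately [0.0, 0.5, 1.0]
print(unnormalize_log_minmax(norm, lo, hi))  # approximately [10, 1000, 100000]
```

Working in log space keeps cardinalities that span several orders of magnitude in a narrow, well-conditioned range for the network.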

```python
def train_and_predict(workload_name, num_queries=1000, num_epochs=100,
                      batch_size=100, hid_units=256, verbose=False, write=False):
    # Load training and validation data
    num_materialized_samples = 1000
    dicts, column_min_max_vals, min_val, max_val, labels_train, \
    labels_test, max_num_joins, max_num_predicates, \
    train_data, test_data = get_train_datasets('all_train_queries.sql', num_queries,
                                               num_materialized_samples)
    table2vec, column2vec, op2vec, join2vec = dicts

    # Train model
    sample_feats = len(table2vec) + num_materialized_samples
    predicate_feats = len(column2vec) + len(op2vec) + 1
    join_feats = len(join2vec)

    model = SetConv(sample_feats, predicate_feats, join_feats, hid_units)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)  # lr=0.001 in the original Kipf et al. code

    train_data_loader = DataLoader(train_data, batch_size=batch_size)
    test_data_loader = DataLoader(test_data, batch_size=batch_size)

    model.train()
    for epoch in range(num_epochs):
        loss_total = 0.

        for batch_idx, data_batch in enumerate(train_data_loader):
            samples, predicates, joins, targets, sample_masks, predicate_masks, join_masks = data_batch

            optimizer.zero_grad()
            outputs = model(samples, predicates, joins, sample_masks, predicate_masks, join_masks)
            loss = qerror_loss(outputs, targets.float(), min_val, max_val)
            loss_total += loss.item()
            loss.backward()
            optimizer.step()
        if verbose:
            print("Epoch {}, loss: {}".format(epoch, loss_total / len(train_data_loader)))

    # Get final training and validation set predictions
    preds_train, t_total = predict(model, train_data_loader)
    if verbose:
        print("Prediction time per training sample (ms): {}".format(t_total / len(labels_train) * 1000))

    preds_test, t_total = predict(model, test_data_loader)
    if verbose:
        print("Prediction time per validation sample (ms): {}".format(t_total / len(labels_test) * 1000))

    # Unnormalize
    preds_train_unnorm = unnormalize_labels(preds_train, min_val, max_val)
    labels_train_unnorm = unnormalize_labels(labels_train, min_val, max_val)

    preds_test_unnorm = unnormalize_labels(preds_test, min_val, max_val)
    labels_test_unnorm = unnormalize_labels(labels_test, min_val, max_val)

    # Print metrics
    if verbose:
        print("\nQ-Error training set:")
        print_qerror(preds_train_unnorm, labels_train_unnorm)
        print("\nQ-Error validation set:")
        print_qerror(preds_test_unnorm, labels_test_unnorm)
        print("")

    # Load test data
    file_name = "workloads/" + workload_name
    joins, predicates, tables, samples, label = load_data(file_name, num_materialized_samples)

    # Get feature encoding and proper normalization (using the training min/max values)
    samples_test = encode_samples(tables, samples, table2vec)
    predicates_test, joins_test = encode_data(predicates, joins, column_min_max_vals, column2vec, op2vec, join2vec)
    labels_test, _, _ = normalize_labels(label, min_val, max_val)
    if verbose:
        print(f"Number of test samples: {len(labels_test)}")

    max_num_predicates = max([len(p) for p in predicates_test])
    max_num_joins = max([len(j) for j in joins_test])

    # Get test set predictions
    test_data = make_dataset(samples_test, predicates_test, joins_test, labels_test, max_num_joins, max_num_predicates)
    test_data_loader = DataLoader(test_data, batch_size=batch_size)

    preds_test, t_total = predict(model, test_data_loader)
    if verbose:
        print(f"Prediction time per test sample (ms): {t_total / len(labels_test) * 1000}")

    # Unnormalize
    preds_test_unnorm = unnormalize_labels(preds_test, min_val, max_val)

    # Print metrics
    print(f"\nQ-Error, {workload_name}:")
    print_qerror(preds_test_unnorm, label)

    # Write predictions
    if write:
        file_name = f"results/predictions_{workload_name}.csv"
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        with open(file_name, "w") as f:
            for i in range(len(preds_test_unnorm)):
                f.write(f'{preds_test_unnorm[i]},{label[i]}\n')
```


```python
print('Original (recreated and retrained) MSCN from Kipf et al.:\n')
start_time = time.time()
train_and_predict(workload_name='job-light', num_queries=5000, num_epochs=1000, batch_size=100, hid_units=256)
print(f'Total Time: {round((time.time()-start_time),4)} seconds')
```

```
Original (recreated and retrained) MSCN from Kipf et al.:

Q-Error, job-light:
Median: 3.829080001743435
90th percentile: 79.58870873669316
95th percentile: 381.1589145561346
99th percentile: 937.5885201549474
Max: 1271.7475329481463
Mean: 44.07001456032248
Total Time: 217.4167 seconds
```


### Adjusted Data Encoding
The two cells below contrast the original MSCN predicate encoding with our featurized predicate encoding.

```python
#### THE ORIGINAL CODE IS AVAILABLE FROM KIPF ET AL, mscn/util.py ####
def encode_data(predicates, joins, column_min_max_vals, column2vec, op2vec, join2vec):
    predicates_enc = []
    joins_enc = []
    for i, query in enumerate(predicates):
        predicates_enc.append(list())
        joins_enc.append(list())
        for predicate in query:
            if len(predicate) == 3:
                # Proper predicate: [column one-hot | operator one-hot | normalized value]
                column = predicate[0]
                operator = predicate[1]
                val = predicate[2]
                norm_val = normalize_data(val, column, column_min_max_vals)

                pred_vec = []
                pred_vec.append(column2vec[column])
                pred_vec.append(op2vec[operator])
                pred_vec.append(norm_val)
                pred_vec = np.hstack(pred_vec)
            else:
                # No proper predicate (e.g. a query without filters): all-zero vector
                pred_vec = np.zeros((len(column2vec) + len(op2vec) + 1))
            # One vector per predicate; the set is padded to max_num_predicates later
            predicates_enc[i].append(pred_vec)

        for predicate in joins[i]:
            # Join instruction
            join_vec = join2vec[predicate]
            joins_enc[i].append(join_vec)
    return predicates_enc, joins_enc
```


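```python
#### OUR UPDATES TO KIPF ET AL DATA ENCODING SCHEMA ####
def encode_data_NEW(predicates, joins, column_min_max_vals, column2vec, op2vec, join2vec):
    # Note: with this encoding, SetConv's predicate_feats becomes len(column2vec) * len(op2vec)
    num_ops = len(op2vec)
    predicates_enc = []
    joins_enc = []
    for i, query in enumerate(predicates):
        predicates_enc.append(list())
        joins_enc.append(list())
        # MAJOR FEATURIZATION CHANGES HERE:
        # one fixed-length vector per query, with a block of num_ops slots for every column
        pred_vec = np.zeros(len(column2vec) * num_ops)
        for predicate in query:
            if len(predicate) != 3:
                continue  # no proper predicate (e.g. a query without filters): leave slots at zero
            column = predicate[0]
            operator = predicate[1]
            val = predicate[2]
            norm_val = normalize_data(val, column, column_min_max_vals)
            col_idx = int(np.argmax(column2vec[column]))  # position of the 1 in the column one-hot
            op_idx = int(np.argmax(op2vec[operator]))     # position of the 1 in the operator one-hot
            # Write the normalized value into the (column, operator) slot
            pred_vec[col_idx * num_ops + op_idx] = norm_val
        predicates_enc[i].append(pred_vec)

        for predicate in joins[i]:
            # Join instruction
            join_vec = join2vec[predicate]
            joins_enc[i].append(join_vec)
    return predicates_enc, joins_enc
```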


### Predicate Encoding Scheme Comparison
Our main insight for improving the featurization is as follows. Suppose we have a query with the following predicates:
$$(b<0.5) \wedge (d>0.2) \wedge (e=0.3)$$
over a set of attributes $\{a,b,c,d,e\}$, where each normalized attribute takes values in $[0,1]$.
Assuming an upper limit of four predicates per query, the Kipf et al. implementation featurizes this predicate set as follows:

```
(Predicate on b)
[0 1 0 0 0 1 0 0 0.5]
 a b c d e < > = val
 -- AND --
(Predicate on d)
[0 0 0 1 0 0 1 0 0.2]
 a b c d e < > = val
 -- AND --
(Predicate on e)
[0 0 0 0 1 0 0 1 0.3]
 a b c d e < > = val
FINAL REPRESENTATION (assuming a four-predicate maximum):
[0 1 0 0 0 1 0 0 0.5 0 0 0 1 0 0 1 0 0.2 0 0 0 0 1 0 0 1 0.3 0 0 0 0 0 0 0 0 0]
Final Length of Predicate Featurization: 36
(average of 7.2 values per table attribute)
```
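The Kipf-style layout above can be rebuilt with a small NumPy sketch: each predicate becomes a column one-hot, an operator one-hot and its normalized value, and the per-predicate vectors are zero-padded to the four-predicate maximum and concatenated. The toy schema (attributes `a` through `e`, operators `<`, `>`, `=`) and the helper names `one_hot` and `kipf_style_encoding` are hypothetical; the real code builds `column2vec` and `op2vec` from the training workload.

```python
import numpy as np

COLUMNS = ["a", "b", "c", "d", "e"]  # hypothetical toy attributes
OPS = ["<", ">", "="]                 # comparison operators
MAX_PREDICATES = 4                    # assumed upper limit on predicates per query

def one_hot(index, size):
    vec = np.zeros(size)
    vec[index] = 1.0
    return vec

def kipf_style_encoding(predicates):
    """Concatenate [column one-hot | operator one-hot | value] per predicate, zero-padded."""
    width = len(COLUMNS) + len(OPS) + 1
    rows = [np.hstack([one_hot(COLUMNS.index(col), len(COLUMNS)),
                       one_hot(OPS.index(op), len(OPS)),
                       [val]])
            for col, op, val in predicates]
    rows += [np.zeros(width)] * (MAX_PREDICATES - len(rows))  # padding slots
    return np.concatenate(rows)

query = [("b", "<", 0.5), ("d", ">", 0.2), ("e", "=", 0.3)]
enc = kipf_style_encoding(query)
print(len(enc))  # 36
# Swapping the first two predicates changes the flattened vector
print(np.array_equal(enc, kipf_style_encoding([query[1], query[0], query[2]])))  # False
```

With five attributes, the 36 values work out to 36 / 5 = 7.2 values per attribute, and the length grows with the predicate limit rather than with the schema.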

By contrast, we choose to featurize this predicate set as follows:

```
 0   0   0    0.5 0   0    0   0   0    0   0.2 0    0   0   0.3    Final Featurization
 <   >   =    <   >   =    <   >   =    <   >   =    <   >   =      Comparison Operators
 a            b            c            d            e              Attributes
FINAL REPRESENTATION:
[0 0 0 0.5 0 0 0 0 0 0 0.2 0 0 0 0.3]
Final Length of Predicate Featurization: 15
(constant average of 3 values per table attribute)
```
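The same toy query under our encoding can be sketched in NumPy as one slot per (attribute, operator) pair, holding the normalized value of the predicate on that pair and zero otherwise. As before, the schema and the name `slot_encoding` are hypothetical, and the sketch assumes at most one predicate per (attribute, operator) pair, since a second one would overwrite the first.

```python
import numpy as np

COLUMNS = ["a", "b", "c", "d", "e"]  # hypothetical toy attributes
OPS = ["<", ">", "="]                 # comparison operators

def slot_encoding(predicates):
    """Fixed-length vector: len(OPS) slots per column; each value goes in its (column, op) slot."""
    vec = np.zeros(len(COLUMNS) * len(OPS))
    for col, op, val in predicates:
        vec[COLUMNS.index(col) * len(OPS) + OPS.index(op)] = val
    return vec

query = [("b", "<", 0.5), ("d", ">", 0.2), ("e", "=", 0.3)]
enc = slot_encoding(query)
print(enc)
# Reordering the predicates yields exactly the same vector
print(np.array_equal(enc, slot_encoding([query[1], query[0], query[2]])))  # True
```

One design consequence to keep in mind: because an absent predicate is also encoded as 0, a predicate whose normalized value is exactly 0 cannot be distinguished from having no predicate on that (attribute, operator) pair.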

This featurization has several benefits. First, there is no upper limit on the number of predicates in a query: the length of the predicate featurization depends only on the number of attributes and operators, not on the number of predicates. Second, the featurization does not depend on predicate order. Under the Kipf et al. encoding, for instance, the two vectors


$$[\color{green}{\text{0, 1, 0, 0, 0, 1, 0, 0, 0.5,}} \color{red}{\text{0, 0, 0, 1, 0, 0, 1, 0, 0.2,}}   0, 0, 0, 0, 1, 0, 0, 1, 0.3, 0, 0, 0, 0, 0, 0, 0, 0, 0]$$

and 

$$[ \color{red}{\text{0, 0, 0, 1, 0, 0, 1, 0, 0.2,}} \color{green}{\text{0, 1, 0, 0, 0, 1, 0, 0, 0.5,}}  0, 0, 0, 0, 1, 0, 0, 1, 0.3, 0, 0, 0, 0, 0, 0, 0, 0, 0]$$

encode the same query (we have merely swapped the order of the predicates on $b$ and $d$) and should therefore map to the same cardinality, yet their featurized representations are very different. Under our encoding, both orderings produce the identical 15-value vector. It would appear that the MSCN is not flexible enough to recognize this equivalence. Even when aggregating over sets of predicates (as the MSCN can be adjusted to do), the improved predicate featurization still outperforms the original implementation.

```python
print('Retrained MSCN Architecture with Updated Featurization:\n')
start_time = time.time()
# Note: train_and_predict_NEW is our altered version of train_and_predict (defined on the back end,
# not shown); it uses the updated featurization together with other modified utils.
train_and_predict_NEW(testset='job-light', num_queries=5000, epochs=1000, batch_size=100, hid=256)
print(f'Total Time: {round((time.time()-start_time),4)} seconds')
```

```
Retrained MSCN Architecture with Updated Featurization:

Q-Error job-light:
Median: 3.3707934686744982
90th percentile: 44.26868661918655
95th percentile: 197.39683996127513
99th percentile: 782.6566606666486
Max: 954.0733123971569
Mean: 41.41337581462835
Total Time: 210.5493 seconds
```


### Current Performance on JOB-Light
| Q-Error Metric (JOB-Light) | Kipf et al. Results | Improved Featurization | Cardinality Sampling |
| ---- | ---- | ---- | ---- |
|Median | 3.82| __3.37__ | 4.55 |
|90th Percentile| 78.4| __44.3__ | 76.3 |
|95th Percentile|362| __197__ | 302 |
|Max|1110| __954__ | 271841 |
|Mean|57.9| __41.4__ | 4872.7 |
|Training Time on GPU (min)|3.6|3.5| __2.8__ |

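We also ran several trials of the cardinality sampling model with different settings (numbers of samples and training queries) to optimize its performance. There is some indication that this method can surpass the state of the art: Trials 3 and 4 beat the MSCN benchmark at the 95th percentile, and Trial 3 also beats it at the 90th percentile. However, the method also shows a long tail of error at the highest percentiles (99th percentile and max). More training queries may alleviate this, but labeled training queries are costly to generate.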

| JOB-light       |M/A Trial 1|M/A Trial 2|M/A Trial 3|M/A Trial 4| MSCN (Benchmark) |
|-----------------|:-----------:|:-----------:|:-----------:|:-----------:|-------|
| Samples         |        5000 |        5000 |       10000 |       10000 | 10000 |
| Queries         |        1000 |        4400 |        1000 |        4400 | 90000 |
| Median          | 7.51 | 5.23 | 4.55 | 5.07 |  __3.82__ |
| 90th percentile | 255.5 | 215.6 | __76.3__ | 128.7 |  78.4 |
| 95th percentile | 607 | 397 | 303 | __279.70__ |   362 |
| 99th percentile | 67952 | 2062 | 122114 | 1333 |   __927__ |
| Max             | 151023 | 2609 |    271841 | 1931 |  __1110__ |
| Mean            | 2758.1 | 117.2 | 4872.7 | 78.8 |  __57.9__ |