In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd

from cardinality_estimation.featurizer import Featurizer
from query_representation.query import load_qrep

import glob
import random
import os
import json
import time
from collections import defaultdict

In [None]:
# Root directory of the query workload used by this notebook.
# Other workloads tried earlier, kept for reference:
#TRAINDIR = os.path.join("queries", "mlsys1-train")
#TESTDIR = os.path.join("queries", "mlsys1-train")
QDIR = os.path.join("queries", "imdb-unique-plans")

In [None]:
def load_qdata(fns):
    """Load each query file in ``fns`` and tag the result with where it came from.

    Every path is read with ``load_qrep``. Two keys are then set on the
    loaded object:

    * ``"name"``          -- base name of the file itself
    * ``"template_name"`` -- base name of the directory containing the file

    Returns a list of the loaded objects, in the same order as ``fns``.
    """
    def _load_tagged(path):
        loaded = load_qrep(path)
        # TODO: can do checks like no queries with zero cardinalities etc.
        loaded["name"] = os.path.basename(path)
        loaded["template_name"] = os.path.basename(os.path.dirname(path))
        return loaded

    return [_load_tagged(path) for path in fns]

def get_query_fns(basedir, template_fraction=1.0):
    """Collect ``*.pkl`` query files from every template directory in ``basedir``.

    Each non-file entry directly under ``basedir`` is treated as one template
    directory. From each of them a reproducible random subset of its
    ``*.pkl`` files is selected.

    Args:
        basedir: directory whose sub-directories hold the query files.
        template_fraction: fraction (at most 1.0) of each template's files to
            keep. At least one file is kept for every template that has any.

    Returns:
        A flat list of selected file paths, grouped by template directory in
        sorted directory order.

    Raises:
        AssertionError: if ``template_fraction`` is greater than 1.0.
    """
    assert template_fraction <= 1.0

    fns = []
    # Sorted so the template order does not depend on filesystem listing order.
    for qdir in sorted(glob.glob(os.path.join(basedir, "*"))):
        if os.path.isfile(qdir):
            continue
        # Sort before sampling so the sample depends only on the file names.
        qfns = sorted(glob.glob(os.path.join(qdir, "*.pkl")))
        if not qfns:
            # Nothing to sample here; random.sample([], 1) would raise ValueError.
            continue
        num_samples = max(int(len(qfns) * template_fraction), 1)
        # A private generator with the same seed picks the same files as
        # seeding the module-level RNG would, without resetting global
        # random state for the rest of the notebook.
        rng = random.Random(1234)
        fns += rng.sample(qfns, num_samples)
    return fns

In [None]:
# Take every query file under QDIR (fraction 1.0 keeps all files of each
# template) and load the corresponding query representations.
train_qfns = get_query_fns(QDIR, template_fraction=1.0)
trainqs = load_qdata(train_qfns)

In [None]:
# predicate column -> largest number of constants seen in one predicate on it
constantmaxs = defaultdict(int)
# predicate column -> set of every distinct constant seen for that column
allconstants = defaultdict(set)

In [None]:
# For every node of every query's join graph, walk its predicate columns and
# record (a) the distinct constants seen per column and (b) the longest
# constant list seen for that column.
for query in trainqs:
    graph_nodes = query["join_graph"].nodes()
    for node in graph_nodes:
        node_info = graph_nodes[node]
        for ci, col in enumerate(node_info["pred_cols"]):
            # pred_vals is indexed in parallel with pred_cols
            consts = node_info["pred_vals"][ci]
            allconstants[col].update(consts)
            constantmaxs[col] = max(constantmaxs[col], len(consts))

In [None]:
import string

def preprocess_word(word, exclude_nums=False, exclude_the=False,
        exclude_words=(), min_len=0):
    """Normalise a value into a lowercase, punctuation-free, space-joined string.

    The value is converted with ``str()``, every ASCII punctuation character
    is deleted (so ``"spider-man"`` becomes ``"spiderman"``), the text is
    lowercased and split on whitespace, and the surviving words are joined
    back with single spaces.

    Args:
        word: value to normalise; any object accepted by ``str()``.
        exclude_nums: if true, digit characters are deleted as well.
        exclude_the: if true, the word ``"the"`` is dropped from the result.
        exclude_words: collection of (lowercase) words to drop.
        min_len: words shorter than this many characters are dropped.

    Returns:
        The cleaned words joined by single spaces (possibly an empty string).
    """
    word = str(word)

    # Characters removed outright: punctuation, and optionally digits.
    drop_chars = set(string.punctuation)
    if exclude_nums:
        drop_chars.update(string.digits)

    # exclude.remove("%")
    word = ''.join(ch for ch in word if ch not in drop_chars)

    # make it lowercase
    word = word.lower()
    final_words = []

    for w in word.split():
        # "the" has to be filtered at word level: putting it into the
        # character set above can never match, since that set is only
        # compared against single characters.
        if exclude_the and w == "the":
            continue
        if w in exclude_words:
            continue
        if len(w) < min_len:
            continue
        final_words.append(w)

    return " ".join(final_words)

In [None]:
from gensim.models import Word2Vec
# Path of the saved Word2Vec model, relative to the notebook's working directory.
model_name = "./sampled_data.bin"
model = Word2Vec.load(model_name)
# Keep only the vectors object; it is what the later cells look tokens up in.
wv = model.wv
# Print the model's latest recorded training loss and its summary repr.
print(model.get_latest_training_loss())
print(model)
# Drop the `model` name; `wv` still references the vectors.
# NOTE(review): presumably done to release the rest of the model's memory -- confirm.
del model

In [None]:
# Build a lookup from "<column><cleaned constant>" to the sum of the word
# vectors of the constant's tokens.
found = 0        # constants with at least one token present in `wv`
not_found = 0    # constants with no token present in `wv`
allvectors = {}

for k, allvals in allconstants.items():
    #print(k, len(allvals), constantmaxs[k])
    for vals in allvals:
        vals = preprocess_word(vals)
        vecs = []
        # preprocess_word returns ONE space-joined string. Split it so that
        # whole tokens are looked up; iterating the string directly would
        # yield single characters instead of words.
        for subval in vals.split():
            if subval in wv:
                vecs.append(wv[subval])
        if len(vecs) > 0:
            found += 1
            # Constants that clean to the same text share a key; they also
            # produce the same vector, so the overwrite is harmless (but
            # `found` counts each of them).
            valkey = k + str(vals)
            valvec = np.sum(np.array(vecs), axis=0)
            allvectors[valkey] = valvec
        else:
            not_found += 1

print(found, not_found)

In [None]:
# Scratch check: element-wise sum of whatever token vectors are left in `vecs`
# from the last constant processed by the embedding loop.
np.asarray(vecs).sum(axis=0)

In [None]:
# Peek at the first ten lookup keys.
list(allvectors)[:10]

In [None]:
import pickle
# Persist the embedding lookup to disk with the newest pickle protocol.
with open('embeddings1.pkl', 'wb') as handle:
    pickle.dump(allvectors, handle, pickle.HIGHEST_PROTOCOL)

In [None]:
# Read the pickle written above back in as a round-trip sanity check.
with open('embeddings1.pkl', 'rb') as handle:
    b = pickle.load(handle)

# Show only the first ten keys rather than the whole dictionary.
list(b)[:10]