In [1]:
# https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

<torch._C.Generator at 0x7f1b301efe30>

In [9]:
id_to_nodes = {"Af":"FileScan FactTable", "Ad":"FileScan DimTable", "B":"BroadcastHashJoin", "C":"SortMergeJoin", "D":"Union","E":"End"}
id_to_operators = {"":"", "1":"Filter", "2":"Project", "3":"HashAggregate", "4":"Sort", "5":"TakeOrderedAndProject","6":"Expand","7":"Window"}

training_data = [
    ("AdB BC CB BC CE".split(), ["1-2", "2", "2-3-3-1-4", "2-4","2-5"]), #TPC-DS Q1
    ("AfD DB BB BC CE".split(), ["1-2", "", "2-3-3", "2-4","2-4"]), #TPC-DS Q2
    ("AfB BB BE".split(), ["1", "2", "2-3-3-5"]), # TPC-DS Q3
    ("AfC CB BC CC CC CC CC CE".split(), ["1-4","2","2-3-1-4","","2","2","2","2-5"]), # TPC-DS Q4
    ("AfC CD DB BB BD DE".split(), ["1-4","2","","2","2-3","3-5"]), # TPC-DS Q5
    ("AfB BC CB BC CE".split(), ["1", "2-4", "2", "2-4", "2-3-3-1-5"]), #TPC-DS Q6
    ("AfB BB BB BB BE".split(), ["1-2","2","2","2","2-3-5"]), #TPC-DS Q7
    ("AfC CB BE".split(), ["1-4","2-3","2-3-5"]), #TPC-DS Q8
    ("AfE".split(), ["1-2-3-2-2"]), #TPC-DS Q9
    ("AdB BC CC CB BB BE".split(), ["1-2","2-4","","1-2","2","2-3-5"]), #TPC-DS Q10
    ("AfC CB BC CC CC CE".split(), ["1-4","2","2-3-1-4","","2","2"]), #TPC-DS Q11
    ("AfB BB BE".split(), ["1","2","2-3-4-7-2"]), #TPC-DS Q12
    ("AfB BB BB BB BB BE".split(), ["1","2","2","2","2","2-3"]), #TPC-DS Q13
    #TPC-DS Q14a
    #TPC-DS Q14b
    ("AfC CC CB BE".split(), ["1-4","2-4","2","2-3"]), #TPC-DS Q15
    ("AfC CC CB BB BB BE".split(), ["1-4","2","","2","2","2-3-3"]), #TPC-DS Q16
    ("AfC CC CB BB BB BB BE".split(), ["1-4","2-4","2","2","2","2","2-3"]), #TPC-DS Q17
    ("AfB BC CB BC CB BB BE".split(), ["1-2","2-4","2","2-4","2","2","2-6-3"]), #TPC-DS Q18
    ("AdB BB BC CB BB BE".split(), ["1-2","2","2-4","2","2","2-3"]), #TPC-DS Q19
    ("AfB BB BE".split(), ["1","2","2-3-4-7-2"]), #TPC-DS Q20
    ("AfB BB BB BE".split(), ["1","2","2","2-3-1"]), #TPC-DS Q21
    ("AdB BB BB BE".split(), ["1-2","2","2","2-6-3-5"]), #TPC-DS Q22
    #TPC-DS Q23a
    #TPC-DS Q23b
    #TPC-DS Q24a
    #TPC-DS Q24b
    ("AfC CC CB BB BB BB BB BE".split(), ["1-4","2-4","2","2","2","2","2","2-3"]), #TPC-DS Q25
    ("AfB BB BB BB BE".split(), ["1-2","2","2","2","2-3-5"]), #TPC-DS Q26
    ("AfB BB BB BB BE".split(), ["1-2","2","2","2","2-6-3-5"]), #TPC-DS Q27
    #TPC-DS Q28 有subquery
    ("AfC CC CB BB BB BB BB BE".split(), ["1-4","2-4","2","2","2","2","2","2-3-5"]), #TPC-DS Q29
    ("AdB BB BC CC CB BE".split(), ["1-2","2","2-3-3-1-4","2-4","2","2-5"]), #TPC-DS Q30
    ("AdB BB BC CC CC CC CE".split(), ["1","2","2-3-4","2","","2","2-4"]), #TPC-DS Q31
    ("AdB BC CB BE".split(), ["1-2","2-3-1-4","2","2-3"]), #TPC-DS Q32
    ("AdB BB BB BD DE".split(), ["1-2","2","2","2-3","3-5"]), #TPC-DS Q33
    ("AdB BB BB BC CE".split(), ["1-2","2","2","2-3-1-4","2-4"]), #TPC-DS Q34
    ("AdB BC CC CC CB BB BE".split(), ["1-2","2-4","","","1-2","2","2-3-5"]), #TPC-DS Q35
    #---------33-----------
    ("AdB BB BB BE".split(), ["1-2","2","2","2-6-3-4-7-2-5"]), #TPC-DS Q36
    
    ("AdB BB BB BC CE".split(), ["1-2","2","2","4","2-3-5"]), #TPC-DS Q37
    ("AdB BC CC CC CE".split(), ["1-2","2-4","2-3-4","","2-3"]), #TPC-DS Q38
    #TPC-DS Q39a
    #TPC-DS Q39b
    ("AfC CB BB BB BE".split(), ["1-4","2","2","2","2-3"]), #TPC-DS Q40

    ("AfB BE".split(), ["1-2-3-1-2","1-3-2"]), #TPC-DS Q41
    ("AdB BB BE".split(), ["1-2","2","2-3-2"]), #TPC-DS Q42
    ("AdB BB BE".split(), ["1-2","2","2-3-2"]), #TPC-DS Q43
    #TPC-DS Q44
    ("AfB BB BB BB BB BE".split(), ["1","2","2","2","2","1-2-3-5"]), #TPC-DS Q45
    ("AdB BB BB BB BC CB BE".split(), ["1-2","2","2","2","2-3-4","2","2"]), #TPC-DS Q46
    ("AdB BB BB BC CC CE".split(), ["1","2","2","2-3-4-7-1-7-1-2-4","2","2-5"]), #TPC-DS Q47
    ("AfB BB BB BB BE".split(), ["1","2","2","2","2-3"]), #TPC-DS Q48
    ("AfC CB BD DE".split(), ["1-4","2","2-3-4-7-4-7-1-2","3-5"]), #TPC-DS Q49
    ("AfC CB BB BB BE".split(), ["1-4","2","2","2","2-3-5"]), #TPC-DS Q50

    ("AdB BC CE".split(), ["1-2","2-3-4-7-2-4","2-4-7-1-5"]), #TPC-DS Q51
    ("AdB BB BE".split(), ["1-2","2","2-3-5"]), #TPC-DS Q52
    ("AfB BB BE".split(), ["1-2","2","2-3-4-7-1-2-5"]), #TPC-DS Q53
    ("AfD DB BB BC CC CB BB BE".split(), ["1-2","","2","2-4","2-4","2","2","2-3-3-3-2"]),#TPC-DS Q54
    ("AdB BB BE".split(), ["1-2","2","2-3-5"]), #TPC-DS Q55
    ("AdB BB BB BD DE".split(), ["1-2","2","2","2-3","3-5"]), #TPC-DS Q56
    ("AfB BB BB BC CC CE".split(), ["1","2","2","2-3-4-7-1-7-1-2-4","2","2-5"]), #TPC-DS Q57
    #TPC-DS Q58
    ("AdB BB BB BC CE".split(), ["1","2-3","2","2-4","2-5"]), #TPC-DS Q59
    ("AdB BB BB BD DE".split(), ["1-2","2","2","2-3","3-5"]), #TPC-DS Q60
    
    # window
    # id_to_nodes = {"Af":"FileScan FactTable", "Ad":"FileScan DimTable", "B":"BroadcastHashJoin", 
    # "C":"SortMergeJoin", "D":"Union"}
    # id_to_operators = {"":"", "1":"Filter", "2":"Project", "3":"HashAggregate", 
    # "4":"Sort", "5":"TakeOrderedAndProject","6":"Expand","7":"Window"}
    ("AfB BB BB BB BB BB BB BE".split(), ["1-2","2","2","2","2","2","2-3","2"]), #TPC-DS Q61
    ("AfB BB BB BB BE".split(), ["1","2","2","2","2-3-5"]), #TPC-DS Q62
    ("AfB BB BB BE".split(), ["1-2","2","2","2-3-4-7-1-2-5"]), #TPC-DS Q63
    ("AfC CC CB BB BC CB BB BB BB BB BB BB BC CC CB BB BB BC CE".split(), ["1-4","2-3-1-2-4","2","2","2-4","2","2","2","2","2","2","2","2-4","2-4","2","2","2","2-3-4","2-4"]), #TPC-DS Q64
    ("AdB BB BB BC CE".split(), ["1-2","2-3-1","2","2-4","2-2"]), #TPC-DS Q65
    ("AdB BB BB BB BD DE".split(), ["1","2","2","2","2-3","3-2"]), #TPC-DS Q66
    ("AdB BB BB BE".split(), ["1-2","2","2","2-6-3-4-7-1"]), #TPC-DS Q67
    ("AdB BB BB BB BC CB BE".split(), ["1-2","2","2","2","2-3-4","2","2-5"]), #TPC-DS Q68
    ("AdB BC CC CC CB BB BE".split(), ["1-2","2-4","","","2","2","2-3-5"]), #TPC-DS Q69
    ("AdB BB BC CB BE".split(), ["1","2","2-3-4-7-1-2","","2-6-3-4-7-2"]), #TPC-DS Q70

    ("AdB BD DB BB BE".split(), ["1-2","2","","2","2-3-4"]), #TPC-DS Q71
    ("AdC CB BB BB BB BB BB BB BB BB BC CE".split(), ["1-4","2","2","2","2","2","2","2","2","2","2-4","2-3-5"]), #TPC-DS Q72
    ("AdB BB BB BC CE".split(), ["1-2","2","2","2-3-1-4","2-4"]), #TPC-DS Q73
    ("AfC CB BC CC CE".split(), ["1-4","2","2-3-1-4","2","2-5"]), #TPC-DS Q74
    ("AfB BB BC CD DC CE".split(), ["1-2","2","2-4","2","3-3-4","2-5"]), #TPC-DS Q75
    ("AfB BB BD DE".split(), ["1","2","2","3-5"]), #TPC-DS Q76
    ("AdB BB BC CD DE".split(), ["1-2","2","2-3-4","2","6-3-5"]), #TPC-DS Q77
    ("AdC CB BC CC CE".split(), ["1-4","1-2","2-3-4","2","1-2-5"]), #TPC-DS Q78
    ("AdB BB BB BC CE".split(), ["1-2","2","2","2-3-4","2-5"]), #TPC-DS Q79
    ("AfC CB BB BB BD DE".split(), ["1-4","2","2","2","2-3","6-3-5"]), #TPC-DS Q80

    ("AdB BB BC CC CC CE".split(), ["1-2","2","2-3-3-1-4","2-4","2-4","2-5"]), #TPC-DS Q81
    ("AdB BB BC CE".split(), ["1-2","2","2-4","2-3-5"]), #TPC-DS Q82
    ("AdB BB BB BC CC CE".split(), ["1-2","2","2","2-3-4","2","2-5"]), #TPC-DS Q83
    ("AfB BB BB BB BC CE".split(), ["1-2","2","2","2","2-4","2-5"]), #TPC-DS Q84
    ("AfC CB BB BB BB BB BB BE".split(), ["1-4","2","2","2","2","2","2","2-3-5"]), #TPC-DS Q85
    ("AdB BB BE".split(), ["1-2","2","6-3-4-7-2-5"]), #TPC-DS Q86
    ("AdB BC CC CC CE".split(), ["1-2","2-4","2-3","2","2-3"]), #TPC-DS Q87
    ("AfB BB BB BB BB BB BB BB BB BB BE".split(), ["1-2","2","2","2-3","","","","","","",""]), #TPC-DS Q88
    ("AfB BB BB BE".split(), ["1","2","2","2-3-4-7-1-2-5"]), #TPC-DS Q89
    ("AfB BB BB BB BE".split(), ["1-2","2","2","2-3","5"]), #TPC-DS Q90

    ("AfB BB BC CB BB BB BE".split(), ["1","2","2-4","2","2","2","2-3-4"]), #TPC-DS Q91
    ("AdB BC CB BE".split(), ["1-2","2-3-1-4","2","2-3"]), #TPC-DS Q92
    ("AfC CB BE".split(), ["1-4","2","2-3-5"]), #TPC-DS Q93
    ("AfC CC CB BB BB BE".split(), ["1-4","2","","2","2","2-3-3"]), #TPC-DS Q94
    ("AfC CC CC CB BB BB BE".split(), ["1-4","2","2","","2","2","2-3-3"]), #TPC-DS Q95
    ("AfB BB BB BE".split(), ["1-2","2","2","2-3"]), #TPC-DS Q96
    ("AdB BC CE".split(), ["1-2","2-3-4","2-3"]), #TPC-DS Q97
    ("AfB BB BE".split(), ["1","2","2-3-4-7-2-4-2"]), #TPC-DS Q98
    ("AfB BB BB BB BE".split(), ["1","2","2","2","2-3-5"]), #TPC-DS Q99

]
word_to_ix = {}
# For each words-list (sentence) and tags-list in each tuple of training_data
for sent, tags in training_data:
    for word in sent:
        if word not in word_to_ix:  # word has not been assigned an index yet
            word_to_ix[word] = len(word_to_ix)  # Assign each word with a unique index
print(word_to_ix)

tag_to_ix={}
for sent, tags in training_data:
    for tag in tags:
        if tag not in tag_to_ix:  # tag has not been assigned an index yet
            tag_to_ix[tag] = len(tag_to_ix)  # Assign each tag with a unique index
print(tag_to_ix)

ix_to_tag = {v:k for k,v in tag_to_ix.items()}
# print(ix_to_tag)




{'AdB': 0, 'BC': 1, 'CB': 2, 'CE': 3, 'AfD': 4, 'DB': 5, 'BB': 6, 'AfB': 7, 'BE': 8, 'AfC': 9, 'CC': 10, 'CD': 11, 'BD': 12, 'DE': 13, 'AfE': 14, 'AdC': 15, 'DC': 16}
{'1-2': 0, '2': 1, '2-3-3-1-4': 2, '2-4': 3, '2-5': 4, '': 5, '2-3-3': 6, '1': 7, '2-3-3-5': 8, '1-4': 9, '2-3-1-4': 10, '2-3': 11, '3-5': 12, '2-3-3-1-5': 13, '2-3-5': 14, '1-2-3-2-2': 15, '2-3-4-7-2': 16, '2-6-3': 17, '2-3-1': 18, '2-6-3-5': 19, '2-3-4': 20, '2-6-3-4-7-2-5': 21, '4': 22, '1-2-3-1-2': 23, '1-3-2': 24, '2-3-2': 25, '1-2-3-5': 26, '2-3-4-7-1-7-1-2-4': 27, '2-3-4-7-4-7-1-2': 28, '2-3-4-7-2-4': 29, '2-4-7-1-5': 30, '2-3-4-7-1-2-5': 31, '2-3-3-3-2': 32, '2-3-1-2-4': 33, '2-2': 34, '3-2': 35, '2-6-3-4-7-1': 36, '2-3-4-7-1-2': 37, '2-6-3-4-7-2': 38, '3-3-4': 39, '6-3-5': 40, '1-2-5': 41, '6-3-4-7-2-5': 42, '5': 43, '2-3-4-7-2-4-2': 44}


In [10]:

frequency_map={}
# for sent, tags in training_data:
#     for word in sent:
#         if word not in frequency_map:  # word has not been assigned an index yet
#             frequency_map[word] = len(word_to_ix)  # Assign each word with a unique index

for sent, tags in training_data:
    for (join_word, seq) in zip(sent, tags):
#         print(join_word, seq)
        if join_word not in frequency_map:
            frequency_map[join_word] = {}
        else:
            word_dict = frequency_map[join_word]
            if seq not in word_dict:
                word_dict[seq] = 1
            else:
                word_dict[seq] = word_dict[seq] +1

    
for word, word_dict in frequency_map.items():
    print("-------------")

#     print("word is {}".format(word))
    part1=""
    if word[0]=="A":
        part1=id_to_nodes[word[0:2]]

    else:
        part1=id_to_nodes[word[0]]
    part2 = id_to_nodes[word[-1]]
    converted_word = part1+"-"+part2
    print("converted_word is {}".format(converted_word))
    for seq_t, cnt in word_dict.items():
#         print("seq_t is {}, cnt is {}".format(seq_t, cnt))
        
        temp_list = seq_t.split("-")
        temp_ans = []
        for key in temp_list:
            temp_ans.append(id_to_operators[key])
        operator_list_str = "-".join(temp_ans)
        print("operator_list_str is {}, cnt is {}".format(operator_list_str, cnt))

        

-------------
converted_word is FileScan DimTable-BroadcastHashJoin
operator_list_str is Filter-Project, cnt is 34
operator_list_str is Filter, cnt is 5
-------------
converted_word is BroadcastHashJoin-SortMergeJoin
operator_list_str is Project-Sort, cnt is 22
operator_list_str is Project-HashAggregate-Filter-Sort, cnt is 7
operator_list_str is Project-HashAggregate-HashAggregate-Filter-Sort, cnt is 2
operator_list_str is Project-HashAggregate-Sort, cnt is 9
operator_list_str is Sort, cnt is 1
operator_list_str is Project-HashAggregate-Sort-Window-Filter-Window-Filter-Project-Sort, cnt is 2
operator_list_str is Project-HashAggregate-Sort-Window-Project-Sort, cnt is 1
operator_list_str is Project-HashAggregate-Sort-Window-Filter-Project, cnt is 1
-------------
converted_word is SortMergeJoin-BroadcastHashJoin
operator_list_str is Project, cnt is 29
operator_list_str is Project-HashAggregate, cnt is 1
operator_list_str is Filter-Project, cnt is 3
operator_list_str is , cnt is 4
--------

In [None]:
ans = torch.max(tag_scores,1)
ans_list = ans.indices.numpy().tolist()
print(ans_list)
for item in ans_list:
    temp_str = ix_to_tag[item]
