In [None]:
import os
import sys
# Add the parent directory to the import path before the project imports
# (common_libs, data, models) in the next cell.
# NOTE(review): presumably those packages live one level above this notebook — confirm.
sys.path.append(os.pardir)

In [None]:
%matplotlib inline

import common_libs.utilities as ut
import data.data_cost as dt
import models.graph_models as md

import pandas as pd
from matplotlib import pyplot as plt
import torch

In [None]:
# Connection object passed to the pd.read_sql queries in the cells below.
cnx = ut.create_connection()

In [None]:
# Every row of `aug_times`, indexed by `time_id`.
times = (
    pd.read_sql('SELECT * FROM aug_times', cnx)
    .set_index('time_id')
)

# The (aug_id, code_id) pairs of `functional_unit_augmentation`, indexed by `aug_id`.
augs = (
    pd.read_sql('SELECT aug_id, code_id FROM functional_unit_augmentation', cnx)
    .set_index('aug_id')
)

In [None]:
# Lookup table used by get_data_item (via .get) to turn raw tokens into ids:
# element [1] of the object stored in the embedding file.
# NOTE(review): torch.load unpickles the file — only load files from a trusted source.
word2id = torch.load('../inputs/embeddings/code_delim.emb')[1]

def get_data_item(code_token):
    """Parse a comma-separated token string into a ``dt.DataItem``.

    The token stream is a sequence of instructions, each laid out as
    ``opcode, -1, src..., -1, dst..., -1`` where ``-1`` is the delimiter.

    Args:
        code_token: comma-separated integers, e.g. ``'1,-1,2,-1,3,-1'``.

    Returns:
        ``dt.DataItem(xs, 0, ut.BasicBlock(instrs))``. ``xs`` holds, for each
        instruction, its raw tokens (delimiters included) mapped through the
        global ``word2id`` (tokens missing from it map to 0). ``instrs`` holds
        one ``ut.Instruction(opcode, srcs, dsts, 1)`` per instruction.

    Raises:
        AssertionError: if an opcode is not followed by the ``-1`` delimiter.
        StopIteration: if the stream ends in the middle of an instruction.
        ValueError: if a token is not an integer literal.
    """
    # Convert everything up front so a bad token fails before any parsing starts.
    tokens_it = iter([int(tok) for tok in code_token.split(',')])

    def read_operands(raw):
        """Consume tokens up to and including the next -1; return the ones before it."""
        operands = []
        while True:
            tok = next(tokens_it)
            raw.append(tok)
            if tok == -1:
                return operands
            operands.append(tok)

    instrs = []
    xs = []
    for opcode in tokens_it:
        # Consume the delimiter *outside* the assert: asserts are stripped under
        # `python -O`, and a next() inside one would then never advance the stream.
        delim = next(tokens_it)
        assert delim == -1, 'expected -1 after opcode {}, got {}'.format(opcode, delim)
        raw = [opcode, delim]
        srcs = read_operands(raw)
        dsts = read_operands(raw)
        instrs.append(ut.Instruction(opcode, srcs, dsts, 1))
        xs.append([word2id.get(tok, 0) for tok in raw])
    return dt.DataItem(xs, 0, ut.BasicBlock(instrs))

In [None]:
# Build the GraphNN with the hyper-parameters below and restore the trained
# weights stored under the checkpoint's 'model' key.
graph = md.GraphNN(embedding_size=256, hidden_size=256, num_classes=1, use_residual=True)
graph.set_learnable_embedding(mode='none', dictsize=1337)
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; load_state_dict copies the values into the parameters above.
graph.load_state_dict(torch.load('../saved/trained_4x4.mdl', map_location='cpu')['model'])
# The notebook only runs inference, so put the model in evaluation mode: any
# dropout / batch-norm style layers then behave deterministically.
# NOTE(review): assumes GraphNN is a torch.nn.Module (it exposes load_state_dict) — confirm.
graph.eval()

def plot_times_of_code_id(code_id):
    """Scatter the positive measured times recorded for ``code_id``.

    Opens a new figure and plots the ``time`` values of the rows of the global
    ``times`` frame whose ``code_id`` matches and whose ``time`` is positive,
    against their position in that selection.
    """
    plt.figure()
    matches_code = times['code_id'] == code_id
    is_positive = times['time'] > 0
    measured = times[matches_code & is_positive]['time']
    plt.title('DATA: functional Unit Augmentation: code_id={}'.format(code_id))
    plt.xlabel('Repetitions')
    plt.ylabel('Actual execution time')
    positions = range(len(measured))
    plt.scatter(positions, measured)

def plot_preds_of_code_id(code_id):
    """Scatter the model's prediction for every token string stored for ``code_id``.

    Reads the ``code_token`` column of ``functional_unit_augmentation`` for the
    given ``code_id`` over a dedicated connection, runs the global ``graph``
    model on each row (via ``get_data_item``) and plots the scalar outputs
    against their row position in a new figure.
    """
    # Use a connection private to this call and always release it, so repeated
    # calls from the notebook do not pile up open connections.
    conn = ut.create_connection()
    try:
        codes = pd.read_sql(
            'SELECT code_token FROM functional_unit_augmentation WHERE code_id={}'.format(code_id),
            conn,
        )
    finally:
        # NOTE(review): assumes the object returned by ut.create_connection()
        # has a DB-API style close() — confirm.
        conn.close()

    # Inference only: no autograd graph is needed for .item().
    with torch.no_grad():
        preds = [graph(get_data_item(token)).item() for token in codes['code_token']]

    # Create the figure only once the data is available, so a failed query or
    # prediction does not leave an empty figure behind.
    plt.figure()
    plt.title('PREDICTION: functional Unit Augmentation: code_id={}'.format(code_id))
    plt.xlabel('Repetitions')
    plt.ylabel('Predicted execution time')
    plt.scatter(range(len(preds)), preds)

In [None]:
# Hard-coded code_ids to inspect; the entry at position 5 is the one passed to
# the plotting functions in the next cell.
interesting_idxs = [468404, 470781, 470762, 469803, 931799, 931801, 931803, 467076, 467731]
interesting_idx = interesting_idxs[5]

In [None]:
# Measured times and model predictions for the selected code_id, one figure each.
plot_times_of_code_id(interesting_idx)
plot_preds_of_code_id(interesting_idx)

In [None]:
# code_ids ordered by how many positive-time rows they have (the per-group
# count of the `arch` column), largest first.
other_idxs = (
    times[times['time'] > 0]
    .groupby('code_id')
    .count()
    .sort_values('arch', ascending=False)
    .index
)
print(other_idxs)

In [None]:
# Print the `code_intel` text of every row stored for one code_id, with an
# 80-dash rule after each.
# NOTE(review): `single_good_idx` is not assigned in any cell visible here —
# confirm where it is defined before running this cell on a fresh kernel.
res = pd.read_sql('SELECT code_intel FROM functional_unit_augmentation WHERE code_id={}'.format(single_good_idx), cnx)
for (_, i) in res.iterrows():
    print(i['code_intel'])
    print('-'*80)

In [None]:
# Iterator that the next cell advances by one element each time it is run.
# NOTE(review): `idxs` is not assigned in any cell visible here — presumably
# one of the code_id collections built above; confirm.
idx_iter = iter(idxs)

In [None]:
# Advance to the next code_id and plot its measured times; re-run this cell to
# step through the iterator.
i = next(idx_iter)
print(i)
# No plt.title() here: plot_times_of_code_id() opens its own figure and titles
# it with the code_id, so a title set beforehand would end up on a separate,
# otherwise empty figure.
plot_times_of_code_id(i)

In [None]:
# Fetch the `code_intel` and `code_token` columns of every row stored for the
# current code_id `i`, then print the `code_intel` value of the last row.
res = pd.read_sql(
    'SELECT code_intel, code_token FROM functional_unit_augmentation WHERE code_id={}'.format(i),
    cnx,
)
print(res['code_intel'].iloc[-1])

In [None]:
# Run the model on the token string of the first row fetched above; the cell
# displays the model's raw output.
graph(get_data_item(res.iloc[0].code_token))

In [None]:
# Load a saved dataset through dt.load_dataset, passing the embedding file and
# the saved data file.
data = dt.load_dataset('../inputs/embeddings/code_delim.emb', '../saved/time_skylake_test.data')

In [None]:
# Spot-check: entry 2 of the encoded `x` of the item at index 1, and the id
# that the -1 delimiter token maps to in word2id.
print(data.data[1].x[2])
print([word2id.get(-1)])

# Kept for manual inspection of the same item's instructions:
# for i in data.data[1].block.instrs:
#     print(i)

In [None]:
# Create a fresh DataInstructionEmbedding and attach raw data to it. This
# rebinds `data`, replacing the dataset loaded by the earlier load_dataset cell.
# NOTE(review): `raw_data` is not assigned in any cell visible here — confirm
# where it comes from before running this cell on a fresh kernel.
data = dt.DataInstructionEmbedding()
data.raw_data = raw_data

In [None]:
# Point the dataset at the embedding file, then run its metadata and
# preparation steps in that order.
# NOTE(review): what read_meta_data() and prepare_data() actually do is defined
# in data.data_cost and is not visible here.
data.set_embedding('../inputs/embeddings/code_delim.emb')
data.read_meta_data()
data.prepare_data()

In [None]:
# Display the dataset's raw_data attribute as the cell output.
# NOTE(review): this shows the whole object — consider a slice if it is large.
data.raw_data

In [None]:
# Ask the dataset for timing data over the open connection.
# NOTE(review): the meaning of the second argument (1) is not visible here —
# confirm against get_timing_data in data.data_cost.
data.get_timing_data(cnx, 1)