## Setup (ignore all this stuff, scroll down to Interpretability)

In [None]:
from __future__ import print_function
import os
import sys
# Put the Ithemal PyTorch sources on the import path. ITHEMAL_HOME must be set
# in the environment; a missing variable raises KeyError here.
ithemal_pytorch_dir = os.path.join(os.environ['ITHEMAL_HOME'], 'learning', 'pytorch')
sys.path.append(ithemal_pytorch_dir)

In [None]:
%matplotlib inline

import numpy as np
import torch
import data.data_cost as dt
import common_libs.utilities as ut
import models.train as tr
import models.graph_models as md
import models.losses as ls
import random
from tqdm import tqdm

In [None]:
# Load the dataset from the embedding file and the saved data file.
# Both are absolute paths from the environment this notebook was written in.
embedding_file = '/home/ithemal/ithemal/learning/pytorch/inputs/embeddings/code_delim.emb'
dataset_file = '/home/ithemal/ithemal/learning/pytorch/saved/time_skylake_1217.data'
data = dt.load_dataset(embedding_file, data_savefile=dataset_file)

In [None]:
# Call remove_edges() on the block of every datum in the dataset.
# NOTE(review): presumably this drops the dependency edges between instructions
# so that each instruction stands alone -- confirm against the block class.
for datum in data.data:
    datum.block.remove_edges()

In [None]:
# The embedding width is read off the second dimension of the embedding matrix
# that came with the dataset; the same matrix seeds the model's embedding table.
embedding_size = data.final_embeddings.shape[1]
model = md.GraphNN(embedding_size, 256, 1, False)
model.set_learnable_embedding(
    mode='none',
    dictsize=max(data.word2id) + 1,
    seed=data.final_embeddings,
)

In [None]:
# Build the project's Train harness around the model and dataset, then point its
# loss, reporting and correctness hooks at the MSE / regression variants with a
# single loss term.
# NOTE(review): the meaning of the third constructor argument (4) is not visible
# from this notebook -- check tr.Train.
train = tr.Train(model, data, 4)
train.loss_fn = ls.mse_loss
train.print_fn = train.print_final
train.correct_fn = train.correct_regression
train.num_losses = 1

In [None]:
# Restore the trained weights from the saved checkpoint (path is relative to the
# notebook's working directory).
# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine; the model in this notebook is only ever fed CPU tensors, and
# load_state_dict copies the values into the model's own parameters.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
state_dict = torch.load('../saved/edges_none_01-07-19_06:42:33.mdl', map_location='cpu')
model.load_state_dict(state_dict['model'])

In [None]:
def predict(item):
    """Run the model on one data item, one root instruction at a time.

    Every root returned by ``item.block.find_roots()`` is passed through
    ``model.lstm_token`` and then ``model.lstm_ins``. The per-root hidden
    states are folded left to right with ``model.reduction`` and the result
    is fed to ``model.linear``.

    Returns:
        A ``(pred, root_hidden)`` pair: the squeezed output of the linear
        layer, and the list of squeezed per-root hidden states in
        ``find_roots()`` order.

    Raises:
        IndexError: if the block has no roots.
    """
    def _root_state(instr):
        # Token LSTM first, then the instruction LSTM on its final hidden state.
        tokens = torch.FloatTensor(instr.tokens).unsqueeze(1)
        _, (ins_embed, _) = model.lstm_token(tokens, model.init_hidden())
        _, (ins_hidden, _) = model.lstm_ins(ins_embed, model.init_hidden())
        return ins_hidden.squeeze()

    # remove_refs is called both before and after the forward pass
    # (presumably to clear per-item state set up by init_bblstm -- verify).
    model.remove_refs(item)
    model.init_bblstm(item)

    root_hidden = [_root_state(instr) for instr in item.block.find_roots()]

    combined = root_hidden[0]
    for other in root_hidden[1:]:
        combined = model.reduction(combined, other)
    pred = model.linear(combined).squeeze()

    model.remove_refs(item)
    return pred, root_hidden

In [None]:
# Evaluate the loaded model on the test split: accumulate the MSE loss and let
# the Train harness count correct predictions in train.correct.
train.correct = 0
total_loss = 0

# Evaluation only: nothing here is back-propagated, so disable autograd to avoid
# building a graph for every test item.
with torch.no_grad():
    for datum in data.test:
        pred, _ = predict(datum)
        y = torch.FloatTensor([datum.y]).squeeze()
        total_loss += ls.mse_loss(pred, y)[0].item()
        train.correct_regression(pred, y)

num_test = float(len(data.test))
print('Test loss: {:.2f}'.format(total_loss / num_test))
print('Test accuracy: {:.2f}'.format(train.correct / num_test))

## Interpretability

Since this is a simple model (LSTM cell on each instruction embedding, `max` to combine instruction embeddings, then a linear layer), we can get some interpretable results:
- `desire`, the prediction of the given instruction if it were the only instruction in the block
- `weight`, the share (as a percentage) of the linear layer's total absolute weight that falls on the slots of the final vector where the instruction won the `max`
- `total contrib`, the actual contribution of the instruction to the final prediction (the sum, over the slots where it won the `max`, of the slot's value times the weight of that slot)

In [None]:
# Break the prediction for one test block down into per-instruction rows:
# desire, weight share and total contribution.
item = data.test[21]

pred, root_hidden = predict(item)

linear_weights = model.linear.weight.data.squeeze()
total_weight = linear_weights.abs().sum()

# Element-wise max over the per-instruction hidden states, together with the
# index of the winning instruction for every slot. This does not depend on the
# loop variable, so it is computed once instead of once per instruction.
max_vals, max_idxs = torch.stack(root_hidden).max(dim=0)

# NOTE(review): rows are labelled with item.block.instrs[i] while root_hidden is
# in find_roots() order -- this assumes the two orders agree; confirm.
for i, hidden in enumerate(root_hidden):
    won = max_idxs == i  # slots where instruction i supplied the max
    i_vals = max_vals[won]
    i_weights = linear_weights[won]
    i_weight_perc = 100 * i_weights.abs().sum() / total_weight
    i_w_contrib = (i_weights * i_vals).sum()
    # Output of the linear layer if this instruction's hidden state were used alone
    i_desire = model.linear(hidden).squeeze().sum().item()
    print('{:<60}: desire {:6.2f}, weight {:2.0f}%, total contrib {:6.2f}'.format(
        item.block.instrs[i],
        i_desire,
        i_weight_perc,
        i_w_contrib,
    ))

print('\npred: {:.2f}, actual: {:.2f}'.format(
    pred.item(),
    item.y,
))

### What is the best we could do with a single constant prediction?

In [None]:
import itertools
from tqdm import tqdm

dataset = data.train

# Each training target y contributes the interval [0.75 * y, 1.25 * y].
starts = sorted(datum.y * 0.75 for datum in dataset)
ends = sorted(datum.y * 1.25 for datum in dataset)

# Merge all interval boundaries into one ordered event list. Ties on the value
# are broken by the tag, and 'end' sorts before 'start'.
events = sorted(itertools.chain(
    ((lo, 'start') for lo in starts),
    ((hi, 'end') for hi in ends),
))

xs = []
ys = []

# Sweep left to right, tracking how many intervals are currently open and
# remembering the boundary at which that count first peaks.
max_val = 0
max_count = 0
curr_count = 0
for val, typ in events:
    if typ == 'end':
        curr_count -= 1
    elif typ == 'start':
        curr_count += 1
        if curr_count > max_count:
            max_count, max_val = curr_count, val
    xs.append(val)
    ys.append(curr_count)

In [None]:
from matplotlib import pyplot as plt
# Plot the number of open intervals (ys) at each boundary value (xs) from the sweep.
plt.plot(xs, ys)
plt.title('Best single prediction (learned from train)')
plt.xlabel('Time prediction')
plt.ylabel('#Correct on prediction')
plt.show()

In [None]:
# Boundary value at which the sweep's open-interval count is highest
best_idx = np.argmax(ys)
m_prediction = xs[best_idx]

# True targets of the test split, and a same-shaped vector holding the constant
# prediction. The 1e-2 offset presumably nudges the value off the interval
# boundary it was taken from -- confirm.
actual = np.array([datum.y for datum in data.test])
single_prediction = m_prediction * np.ones_like(actual) + 1e-2

In [None]:
# A prediction counts as correct when its error is below 25% of the TRUE value.
# That is the criterion the [0.75 * y, 1.25 * y] intervals of the sweep encode,
# so the relative error is taken with respect to `actual`, not the prediction
# (dividing by the prediction scores a different window than the one the
# constant was chosen for).
# NOTE(review): presumably this also matches train.correct_regression -- verify.
rel_err_pct = np.abs(actual - single_prediction) / (actual + 1e-3) * 100
n_correct = np.sum(rel_err_pct < 25)
print('Correct: {}'.format(n_correct / float(len(data.test))))