In [None]:
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
%matplotlib inline

import itertools
from contextlib import nullcontext
import os, sys
# Add ./code2vec to the module search path before the code2vec imports below
# (presumably so that package's own absolute imports resolve — confirm).
sys.path.append('./code2vec')
# Replace the kernel's argv with a clean, script-like argv. A later cell adds
# code2vec command-line flags to this list before Config is built with
# load_from_args=True.
sys.argv = ['test.py']
# NOTE(review): `sys.argc` is not a standard attribute of the `sys` module and
# is not kept in sync when sys.argv changes later; presumably unused — confirm.
sys.argc = len(sys.argv)

from code2vec.config import Config
from code2vec.code2vec import load_model_dynamically

from data.dataloader import MatrixProgramEmbeddingDataset

from data.gen_matrix_progs import gen_matrix_progs_col_stripe as gen_progs

from nets.matrix_predictor_model import MatrixPredictor

In [None]:
# Keyword arguments handed to the program generator for the evaluation set.
# N is kept as a separate name because later cells use it as the matrix side
# length (np.zeros((N, N)), view(-1, 1, N, N)).
N = 16
prog_config = {
    "N": N,
    "num_programs": 8,
    "num_statements": 3,
    "stripe_size": 6,
    "stripe_chance": 0.25,
}

In [None]:
# Materialize the generated programs once: the same list is reused below to
# build the dataset and to print the program sources.
progs = [*gen_progs(**prog_config)]

In [None]:
# Set up the Code2Vec model

MODEL_LOAD_PATH = '../data/java14m_model/models/java14_model/saved_model_iter8.release'

# Config is built with load_from_args=True, so the code2vec flags are passed
# through sys.argv (presumably parsed there — confirm against code2vec.config).
# Rebuild argv from the program name rather than extending it in place, so
# re-running this cell does not keep appending duplicate flags.
sys.argv = sys.argv[:1] + ['--load', MODEL_LOAD_PATH, '--export_code_vectors']

config = Config(set_defaults=True, load_from_args=True, verify=True)
c2v_model = load_model_dynamically(config)

In [None]:
# Create Torch dataset
# The dataset receives a copy of the program list plus two option dicts: one
# carrying the code2vec config/model, one carrying an N x N zero base array.
c2v_options = {"config": config, "model": c2v_model}
matrix_options = {"base_array": np.zeros((N, N))}
all_data = MatrixProgramEmbeddingDataset(list(progs), c2v_options, matrix_options)

# One unshuffled batch holds `num_programs` samples, in dataset order.
dataloader = torch.utils.data.DataLoader(
    all_data,
    batch_size=prog_config["num_programs"],
    shuffle=False,
)

In [None]:
# Load trained model

output_dir = 'models/reconstruction/'

# NOTE(review): the checkpoint path is a placeholder and `output_dir` is not
# used in this cell — point the path at a real checkpoint before running.
checkpoint_path = "/path/to/model.pt"
cpu_device = torch.device('cpu')
# map_location remaps the stored tensors onto the CPU while loading.
checkpoint = torch.load(checkpoint_path, map_location=cpu_device)

# Rebuild the predictor from the code-vector size and matrix side length, then
# restore its weights from the checkpoint's 'model_state_dict' entry.
model = MatrixPredictor(config.CODE_VECTOR_SIZE, N)
model.load_state_dict(checkpoint['model_state_dict'])
# Switch to evaluation mode (also the cell's displayed value).
model.eval()

In [None]:
# Collect model outputs and ground truths into a plottable image

def _as_rgb_grid(images, **grid_kwargs):
    """Tile a batch of N x N matrices into one 3-channel image grid.

    `images` is reshaped to (-1, 1, N, N) and the single channel is repeated
    three times before being passed to torchvision's make_grid with a white
    (1.0) padding value. Extra keyword arguments (e.g. `nrow`) are forwarded
    to make_grid. Returns the grid tensor produced by make_grid.
    """
    rgb = torch.cat([images.view(-1, 1, N, N)] * 3, dim=1)
    return torchvision.utils.make_grid(rgb, pad_value=1.0, **grid_kwargs)


with torch.no_grad():
    # Only the first batch is visualized.
    code_vectors, matrices = next(iter(dataloader))
    outputs = model(code_vectors)
    num_shown = outputs.shape[0]

    # Ground truths first, then the predictions, so that with `nrow` equal to
    # the batch size the first grid row is ground truth and the second row is
    # the corresponding prediction.
    all_tensors = [matrices[i] for i in range(num_shown)]
    all_tensors += [outputs[i] for i in range(num_shown)]

    out_grid = _as_rgb_grid(outputs)
    gt_grid = _as_rgb_grid(matrices)
    # Derive the row width from the batch instead of hard-coding 8, so the
    # two-row layout survives a change of `num_programs`.
    all_grid = _as_rgb_grid(torch.stack(all_tensors), nrow=num_shown)

# Plot matrix reconstructions

* The top row shows the matrices generated by the programs (which are not given as inputs to the network).
* The bottom row visualizes the predicted matrices.

In [None]:
# The grid tensor is channel-first; imshow expects the channel axis last, so
# reorder the axes (C, H, W) -> (H, W, C) before plotting.
grid_image = all_grid.numpy().transpose(1, 2, 0)

fig = plt.gcf()
fig.set_size_inches(22, 8.5)
plt.imshow(grid_image)

In [None]:
# Show the source of every program behind the figure above, each one followed
# by a divider line.

divider = '\n========'
for prog in list(progs):
    print(prog, divider)