In [1]:
import sys
sys.path.append("../")
from utilities.data import BPEVocabulary
from models.Transformer.TransformerLightning import TransformerLightning
import torch
import os
from utilities.C2SDataSet import C2SDataSet
from argparse import Namespace
from pathlib import Path
from torch.utils.data import DataLoader
from apex import amp

In [2]:
from utilities.data import BPEVocabulary
from models.Transformer.TransformerLightning import TransformerLightning
import torch
from utilities.C2SDataSet import C2SDataSet
from argparse import Namespace
from torch.utils.data import DataLoader
from torch.nn import functional as F

def create_dataset_from_hparams(filepath, hparams, vocabulary, variables=False):
    return C2SDataSet(
        filepath,
        hparams.max_contexts,
        vocabulary,
        hparams.subtoken_len,
        hparams.ast_len,
        hparams.target_len,
        shuffle=True,
        variable_only_filter=variables,
        line_cache=None,
    )

device = torch.device("cuda")
def batch_to_cuda(batch):
    label, input_tensors = batch
    label = label.cuda()
    input_tensors = (t.to(device) for t in input_tensors)
    
    return (label, input_tensors)


def run_inference(filepath, model, hparams, variables=False):
    dataset = create_dataset_from_hparams(filepath, hparams, model.vocabulary, variables=variables)
    dataloader = DataLoader(dataset, batch_size=1)

    for batch in dataloader:

        batch = batch_to_cuda(batch)

        predictions, raw = model.predict(batch)
        predictions = predictions.tolist()
        
        labels = batch[0].tolist()

        prediction = model.vocabulary.decode_target(predictions[0])
        label = model.vocabulary.decode_target(labels[0])

        print(f"{label} -> {prediction}")


def load_model(path, model_name, root = None):
    if not root:
        root = path
        
    checkpoint = torch.load(os.path.join(path, model_name))
    hparams = Namespace(**checkpoint["hparams"])
    # update vocabulary location
    for k, v in hparams.__dict__.items():
        if isinstance(v, str) and (v.endswith("-vocab.json") or v.endswith("-merges.txt") or os.path.basename(v) == "node_dict.pkl"):
            setattr(hparams, k, os.path.join(root, os.path.basename(v)))
#             print(os.path.join(path, os.path.basename(v)))
            
    model = TransformerLightning(hparams)
    model.load_state_dict(checkpoint["state_dict"])

    model.eval()
    model.cuda()
    return model, hparams

In [3]:
def create_dataset_from_hparams(filepath, hparams, vocabulary):
    return C2SDataSet(
        filepath,
        hparams.max_contexts,
        vocabulary,
        hparams.subtoken_len,
        hparams.ast_len,
        hparams.target_len,
        shuffle=True,
        variable_only_filter=True,
        line_cache=None,
    )

def run_inference(filepath, model, hparams):
    dataset = create_dataset_from_hparams(filepath, hparams, model.vocabulary)
    dataloader = DataLoader(dataset, batch_size=1)

    for batch in dataloader:

        (
            labels,
            (start, end, path, masks, start_lengths, end_lengths, ast_path_lengths),
        ) = batch
        labels = labels.cuda()
        start = start.cuda()
        end = end.cuda()
        path = path.cuda()
        masks = masks.cuda()
        start_lengths = start_lengths.cuda()
        end_lengths = end_lengths.cuda()
        ast_path_lengths = ast_path_lengths.cuda()
        batch = (
            labels,
            (start, end, path, masks, start_lengths, end_lengths, ast_path_lengths),
        )

        predictions, raw = model.predict(batch)
        # print()
        # print(f"raw: {raw.shape}")
        # print_topk(raw, labels, index_to_target)
        predictions = predictions.tolist()
        labels = labels.tolist()


        prediction = model.vocabulary.decode_target(predictions[0])

        label = model.vocabulary.decode_target(labels[0])
        # print(predictions[0])
        # print(label[0])

        print(prediction)
        print(label)

In [4]:
# input: B x * x ... x *
# dim: 0 <= scalar
# index: M
def batched_index_select(input, dim, index):
    views = [1 if i != dim else -1 for i in range(len(input.shape))]
    expanse = list(input.shape)
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    # making the first dim of output be B
    return torch.cat(torch.chunk(torch.gather(input, dim, index), chunks=index.shape[0], dim=dim), dim=0)

def incremental_decode(model, batch, current_decoder_output):
    with torch.no_grad():
        (start, end, path, masks, start_lengths, end_lengths, ast_path_lengths) = batch

    return model(start,
                    end,
                    path,
                    start_lengths,
                    end_lengths,
                    ast_path_lengths,
                    masks, current_decoder_output)

from einops import rearrange

def beam_search(model, batch, beam_size=3):
    label, _ = batch
    batch_size = label.shape[0]
    max_seq_len = label.shape[-1]
    batch = [item.cuda() for item in batch[1]]

    sos_idx = model.vocabulary.encode_target(model.vocabulary.SOS_TOKEN)
    eos_idx = model.vocabulary.encode_target(model.vocabulary.EOS_TOKEN)

    current_decoder_outputs = torch.tensor([sos_idx] * batch_size).reshape(batch_size, 1, 1).long().cuda()
    current_decoder_scores = torch.zeros(batch_size, 1).cuda()
    # current_decoder_output: (batch_size, beam_size, seq_len)

    print(f"current_decoder_outputs: {current_decoder_outputs.shape}")
    while current_decoder_outputs.shape[-1] < max_seq_len:
        # print()
        # print(f"original batch size: {batch[0].shape}")
        # print(f"current_decoder_outputs: {current_decoder_outputs.shape}")
        # current_decoder_output: batch_size, current_beam_size, seq_len
        current_beam_size = current_decoder_outputs.shape[1]
        seq_len = current_decoder_outputs.shape[-1]
        # print(current_beam_size)
        current_decoder_outputs = current_decoder_outputs.reshape(-1, current_decoder_outputs.shape[-1])
        #current_decoder_output: batch_size * current_beam_size, seq_len
        current_batch = list(
            map(lambda x: x.repeat_interleave(current_beam_size, dim=0), batch)
        )
        current_increment = incremental_decode(
            model, 
            current_batch, 
            current_decoder_outputs)

        predictions = F.log_softmax(current_increment, dim=2)[:, -1, :]
        predictions = predictions.reshape(batch_size, -1)
        # print(f"{predictions.shape}")

        topk = predictions.topk(beam_size, dim=-1)
        topk_scores = topk.values.reshape(-1, 1)
        topk_predictions = topk.indices.reshape(-1, 1)

        # print(f"topk_scores: {topk_scores.shape}")

        current_decoder_scores_repeated = current_decoder_scores
        if current_decoder_scores.shape[0] < topk_scores.shape[0]:
            current_decoder_scores_repeated = current_decoder_scores.repeat_interleave(beam_size).cuda()

        candidate_scores = current_decoder_scores_repeated.reshape(-1) + topk_scores.reshape(-1)
        # print(f"candidate: {candidate_scores.shape}")
        candidate_scores = candidate_scores.reshape(-1, 1)

        current_decoder_outputs_repeated = current_decoder_outputs
        if current_decoder_outputs.shape[0] < topk_predictions.shape[0]:
            current_decoder_outputs_repeated = current_decoder_outputs.repeat_interleave(beam_size, dim=0).cuda()

        candidate_outputs = torch.cat([current_decoder_outputs_repeated, topk_predictions], dim=1)
        assert current_decoder_outputs_repeated.shape[0] == topk_predictions.shape[0]


        # now we have both candidate scores and outputs, we need to prune paths

        candidate_scores = candidate_scores.reshape(batch_size, -1)
        # print(candidate_scores.shape)
        indices_to_keep = candidate_scores.topk(beam_size, dim=-1).indices
        pruned_scores = torch.stack([candidate_scores[i, indices_to_keep[i]] for i in range(indices_to_keep.shape[0])])

        assert (pruned_scores == candidate_scores).all()
        pruned_scores = pruned_scores.reshape(-1, 1)
        assert pruned_scores.shape[0] == beam_size * batch_size

        candidate_outputs = candidate_outputs.reshape(batch_size, -1, seq_len+1)
        # print(candidate_outputs.shape)
        pruned_outputs = torch.stack([candidate_outputs[i, indices_to_keep[i], :] for i in range(indices_to_keep.shape[0])])
        assert pruned_outputs.shape[0] == batch_size
        assert pruned_outputs.shape[1] == beam_size

        current_decoder_scores = pruned_scores
        current_decoder_outputs = pruned_outputs
        assert current_decoder_outputs.shape[1] <= beam_size
    # break
    return current_decoder_outputs, current_decoder_scores.reshape(batch_size, beam_size)





In [15]:
data_root = "../data/inference-test/"
ckp = "../lightning_logs/version_11/checkpoints/epoch=9.ckpt"
# %cd ..
# _, hparams = load_model("./lightning_logs/version_11/checkpoints/", "epoch=9.ckpt", root="../data/var-name/")
model = TransformerLightning.load_from_checkpoint(checkpoint_path="./lightning_logs/version_11/checkpoints/epoch=9.ckpt")
model.freeze()
root = "/media/devjeetroy/ss/testing-project/survey/internal/code_samples/processed/keycloak/variables/"
java_files = list(map(str, Path(root).glob("./**/*.java")))
java_files = [f for f in java_files if os.path.isfile(f)]
java_files.sort()

batch_size = 3
beam_size = 5

def create_dataset_from_hparams(filepath, hparams, vocabulary, variables=False):
    return C2SDataSet(
        filepath,
        hparams.max_contexts,
        vocabulary,
        hparams.subtoken_len,
        hparams.ast_len,
        hparams.target_len,
        shuffle=True,
        variable_only_filter=variables,
        line_cache=None,
    )

dataset = create_dataset_from_hparams(java_files[55], hparams, model.vocabulary, variables=True)
dataloader = DataLoader(dataset, batch_size=batch_size)
for batch in dataloader:
    break


line cache not found


In [16]:
labels, _ = batch
with torch.no_grad():
    outputs, scores = beam_search(model.cuda(), batch, beam_size=beam_size)

RuntimeError: CUDA error: device-side assert triggered

In [17]:
# model, hparams = load_model(v_models, "epoch=9 new.ckpt")
# model = model.to(device)

In [18]:
labels = labels.tolist()
outputs = outputs.tolist()

for label, predictions in zip(labels, outputs):
    # print(labels)
    print(f"Label: {model.vocabulary.decode_target(label)}")
    for prediction in predictions:
        print(f"Prediction: {model.vocabulary.decode_target(prediction)}")
        print()

NameError: name 'outputs' is not defined

In [20]:

def run_inference(filepath, model, hparams, variables=False):
    dataset = create_dataset_from_hparams(filepath, hparams, model.vocabulary, variables=variables)
    dataloader = DataLoader(dataset, batch_size=1)

    for batch in dataloader:

#         batch = batch_to_cuda(batch)

        predictions, raw = model.predict(batch)
        predictions = predictions.tolist()
        
        labels = batch[0].tolist()

        prediction = model.vocabulary.decode_target(predictions[0])
        label = model.vocabulary.decode_target(labels[0])

        print(f"{label} -> {prediction}")

# model.as_type
for i  in range(len(java_files)):
    if "variables" not in java_files[i].split("/"):
        continue
    print(i)
    class_name = java_files[i].split("/")[-2:][0]
    # method_name = os.path.basename(java_files[i])
    class_method = f"{class_name}$"
    print(f"{class_method}")
    run_inference(java_files[i], model, hparams, variables=True)

0
KeycloakUriBuilder_ESTest.java$
line cache not found
keycloak uri -> uri details1  uri
keycloak uri -> sb  uri 
keycloak uri -> uri  uri 
object array0  -> args  data 
1
KeycloakUriBuilder_ESTest.java$
line cache not found
keycloak uri -> builder  builder1 
2
KeycloakUriBuilder_ESTest.java$
line cache not found
string array0  -> args  args 
e  -> e e 
3
KeycloakUriBuilder_ESTest.java$
line cache not found
keycloak uri -> uri  uri 
4
KeycloakUriBuilder_ESTest.java$
line cache not found
keycloak uri -> uri  details 
5
KeycloakUriBuilder_ESTest.java$
line cache not found
keycloak uri -> uri details1  uri
u ri0  -> uri  uri 
6
KeycloakUriBuilder_ESTest.java$
line cache not found
string0  -> template  xml 
keycloak uri -> builder  uri 
keycloak uri -> builder  
object array0  -> args  data 
7
KeycloakUriBuilder_ESTest.java$
line cache not found
keycloak uri -> uri  uri 
8
KeycloakUriBuilder_ESTest.java$
line cache not found
keycloak uri -> builder  uri 
integer0  -> expected  path 
9
Keyc

line cache not found
keycloak uri -> builder  details 
e  -> e  ex 
68
KeycloakUriBuilder_ESTest.java$
line cache not found
string array0  -> args  args 
keycloak uri -> uri details1  uri
u ri0  -> uri2  uri 
69
KeycloakUriBuilder_ESTest.java$
line cache not found
keycloak uri -> uri details1  uri
e  -> ex  ex 
70
KeycloakUriBuilder_ESTest.java$
line cache not found
keycloak uri -> builder  details 
e  -> ex  ex 
71
KeycloakUriBuilder_ESTest.java$
line cache not found
boolean0  -> result  result 
72
KeycloakUriBuilder_ESTest.java$
line cache not found
keycloak uri -> details  details 
e  -> ex  ex 
73
KeycloakUriBuilder_ESTest.java$
line cache not found
e  -> ex  ex 
74
KeycloakUriBuilder_ESTest.java$
line cache not found
keycloak uri -> uri details1  uri
u ri0  -> uri  uri 
e  -> e  ex 
75
KeycloakUriBuilder_ESTest.java$
line cache not found
keycloak uri -> uri details1  uri
integer0  -> a  
e  -> ex  ex 
76
KeycloakUriBuilder_ESTest.java$
line cache not found
keycloak uri -> uri deta

e  -> e  e 
131
KeycloakUriBuilder_ESTest.java$
line cache not found
string array0  -> args  args 
keycloak uri -> uri details1  uri
u ri0  -> uri2  case 
132
KeycloakUriBuilder_ESTest.java$
line cache not found
keycloak uri -> uri  uri 
keycloak uri -> uri details1  uri
u ri0  -> ws uri  uri
u ri1  -> alt key uri 
133
KeycloakUriBuilder_ESTest.java$
line cache not found
string array0  -> args  args 
keycloak uri -> uri details1  uri
u ri0  -> uri  uri 
134
KeycloakUriBuilder_ESTest.java$
line cache not found
u ri2  -> alt key uri 
u ri0  -> uri  uri 
u ri1  -> alt key uri 
135
KeycloakUriBuilder_ESTest.java$
line cache not found
keycloak uri -> uri  uri 
u ri0  -> uri  uri 
136
KeycloakUriBuilder_ESTest.java$
line cache not found
e  -> e  e 
137
KeycloakUriBuilder_ESTest.java$
line cache not found
u ri2  -> alt key uri 
u ri0  -> ws uri  uri
u ri1  -> alt key uri 
138
KeycloakUriBuilder_ESTest.java$
line cache not found
u ri1  -> alt key uri 
keycloak uri -> uri builder  uri
u ri0  ->

line cache not found
e  -> e  e 
195
UriUtils_ESTest.java$
line cache not found
string0  -> s  
196
UriUtils_ESTest.java$
line cache not found
string0  -> s ids 
197
UriUtils_ESTest.java$
line cache not found
ssl required0  -> ssl code  ssl
198
UriUtils_ESTest.java$
line cache not found
ssl required0  -> ssl code  ssl
199
UriUtils_ESTest.java$
line cache not found
ssl required0  -> ssl ssl ssl ssl ssl
200
UriUtils_ESTest.java$
line cache not found
ssl required0  -> ssl  ssl 
e  -> e  ex 
201
UriUtils_ESTest.java$
line cache not found
multivalued hash map0  -> map  data 
202
UriUtils_ESTest.java$
line cache not found
multivalued hash map0  -> map  map 
203
UriUtils_ESTest.java$
line cache not found
u ri0  -> uri  uri 
string0  -> result  url 
204
UriUtils_ESTest.java$
line cache not found
string1  -> expected  name 
u ri0  -> uri  uri 
e  -> e  e 
string0  -> base uri  url
205
UriUtils_ESTest.java$
line cache not found
e  -> e  e 
206
UriUtils_ESTest.java$
line cache not found
boolean0 