In [1]:
import os
import subprocess

# Change the working directory to the git repository root so that the relative
# paths used in the rest of the notebook (data/, models/, out/) resolve.
# check=True makes a failing `git rev-parse` (e.g. when the notebook is not run
# from inside a repository) raise CalledProcessError right here instead of
# handing an empty path to os.chdir.
result = subprocess.run(
    ['git', 'rev-parse', '--show-toplevel'],
    stdout=subprocess.PIPE,
    check=True,
)
# git terminates the path with a newline; strip it before using the path.
root_dir = result.stdout.decode('utf-8').strip()
os.chdir(root_dir)

In [2]:
import numpy as np

from codegen_sources.model.translate import Translator
from codegen_sources.scripts.knnmt.load_functions import extract_functions, load_validation_functions
from codegen_sources.scripts.knnmt.knnmt import KNNMT

# The datastore helpers are imported on a best-effort basis: in this notebook
# they are only called when the datastore directory does not exist yet and
# has to be built.
# NOTE(review): a failed import is silently ignored here, so the build branch
# would later fail with NameError — confirm this is the intended fallback.
try:
    from codegen_sources.scripts.knnmt.datastore import add_to_datastore, train_datastore
except ImportError:
    pass

adding to path /pfs/data5/home/hd/hd_hd/hd_tf268/code-gen


In [3]:
# Directory (relative to the repository root) holding the
# `transcoder_test.<language>.tok` files that `output_sample` reads.
DATASET_PATH = "data/test_dataset"

In [4]:
def output_sample(
    knnmt: KNNMT,
    translator: Translator,
    language_pair: str,
    data_index: int,
    max_tokens: int = 200,
) -> None:
    """Build a translation token by token from nearest-neighbor predictions and print it.

    The source function at ``data_index`` is read from
    ``{DATASET_PATH}/transcoder_test.{src_language}.tok`` and passed through
    ``translator.tokenize``. ``predict_next_token`` is then called repeatedly;
    each call contributes one predicted token to the generated string. Finally
    the accumulated TransCoder targets, the generated string and the ground
    truth function from the target-language file are printed.

    Args:
        knnmt: KNNMT instance forwarded to ``predict_next_token``.
        translator: Translator used for tokenizing, feature extraction and
            translation. Its ``use_knn_store`` attribute is set to ``False``
            by this function and is not restored afterwards.
        language_pair: Source and target language joined by an underscore,
            e.g. ``"cpp_java"``.
        data_index: Index of the function to pick from the extracted source
            and target function lists.
        max_tokens: Generation stops once ``generated.split(' ')`` has at
            least this many pieces. Every appended token is preceded by a
            space, so one of those pieces is the leading empty string.
            Defaults to 200, the previously hard-coded limit.
    """
    parts = language_pair.split("_")
    src_language, tgt_language = parts[0], parts[1]

    # Get tokenized source function
    source_functions = extract_functions(f"{DATASET_PATH}/transcoder_test.{src_language}.tok")
    source = translator.tokenize(source_functions[data_index], src_language)

    generated = ""
    tc_prediction = ""

    # Predict target tokens using kNN-MT only, until the end-of-sequence
    # marker shows up or the length limit is reached.
    while "</s>" not in generated and len(generated.split(" ")) < max_tokens:
        # The third return value (stored input of the first neighbor) is
        # already printed by predict_next_token and is not needed here.
        tc_target, prediction, _ = predict_next_token(
            knnmt, translator, src_language, tgt_language, source, generated
        )
        generated += " " + prediction
        tc_prediction += " " + tc_target

    # Translate once with the translator as configured and once with the kNN
    # store switched off (original TransCoder translation).
    # NOTE(review): neither result is printed or returned; the calls and the
    # `use_knn_store` change are kept because they may have side effects
    # callers rely on — confirm whether they should be reported or removed.
    translation = translator.translate(source, src_language, tgt_language)[0]
    translator.use_knn_store = False
    original_translation = translator.translate(source, src_language, tgt_language)[0]

    # Ground truth: the function at the same index in the target-language file.
    target_functions = extract_functions(f"{DATASET_PATH}/transcoder_test.{tgt_language}.tok")
    target = translator.tokenize(target_functions[data_index], tgt_language)

    # [1:] drops the leading space introduced by the " " + token concatenation.
    print("\n\n\n\n\n")
    print(f"TC PREDICTION: '{tc_prediction[1:]}'")
    print(f"FINAL PREDICTION: '{generated[1:]}'")
    print(f"GROUND TRUTH: '{target} </s>'")
    print("\n\n")

In [5]:
def predict_next_token(
    knnmt: KNNMT,
    translator: Translator,
    src_language: str,
    tgt_language: str,
    source: str,
    generated: str
):
    """Retrieve the nearest datastore neighbors for the next decoding step and report them.

    Decoder features for ``source`` and the text generated so far are obtained
    from ``translator.get_features``; the last feature vector is used to query
    ``knnmt.get_k_nearest_neighbors`` for the ``{src_language}_{tgt_language}``
    pair. The query context, the neighbor tokens, their distances and the
    stored input of the first neighbor are printed.

    Args:
        knnmt: Object providing ``get_k_nearest_neighbors``.
        translator: Object providing ``get_features`` and ``get_token``.
        src_language: Source language identifier.
        tgt_language: Target language identifier.
        source: Tokenized source code passed as ``input_code``.
        generated: Target text produced so far, passed as ``target_code``.

    Returns:
        A tuple ``(target_tokens[-1], tokens[0], inputs[0][0])``: the last
        target token reported by ``get_features``, the token of the first
        returned neighbor, and the stored input entry of that neighbor.
    """
    # Only the decoder features (1st value) and the target tokens (4th value)
    # of the six returned values are used.
    feature_output = translator.get_features(
        input_code=source,
        target_code=generated,
        src_language=src_language,
        tgt_language=tgt_language,
        predict_single_token=True,
        tokenized=True,
    )
    decoder_features, _, _, target_tokens, _, _ = feature_output

    # Query the datastore with the last decoder feature vector, asking for the
    # stored inputs of the neighbors as well.
    query = decoder_features[-1].unsqueeze(0)
    pair_key = f"{src_language}_{tgt_language}"
    knns, distances, inputs = knnmt.get_k_nearest_neighbors(query, pair_key, with_inputs=True)
    tokens = [translator.get_token(neighbor_id) for neighbor_id in knns[0]]

    heavy_rule = "=" * 100
    light_rule = "-" * 100

    print("\n\n\n\n\n")
    print(heavy_rule)
    print(f"SOURCE: '{source}'")
    print(f"GENERATED: '{generated[1:]}'\n")
    print(light_rule)

    next_target = target_tokens[-1]
    print(f"\nNEXT TC TARGET: '{next_target}'")
    print(f"PREDICTIONS: {tokens}")
    print(f"DISTANCES: {distances[0].astype(int)}\n")
    print(light_rule)

    # Stored (source, target) entry of the first returned neighbor; the first
    # five characters of its source are skipped when printing
    # (presumably a leading marker — confirm).
    first_neighbor_input = inputs[0][0]
    print(f"\nINPUT SOURCE: '{first_neighbor_input[0][5:]}'")
    print(f"INPUT TARGET: '{first_neighbor_input[1]}'")
    print(heavy_rule)

    return next_target, tokens[0], first_neighbor_input

In [6]:
# Location of the demo datastore, used both to construct the KNNMT instance
# and to decide whether the datastore still has to be built.
KNNMT_DIR = "out/knnmt/one_click_demo"

knnmt = KNNMT(KNNMT_DIR)

# Build the datastore from the validation functions when it is not on disk yet.
# NOTE(review): the existence check runs after KNNMT(...) was constructed; if
# the constructor creates the directory this branch never runs — confirm.
if not os.path.exists(KNNMT_DIR):
    # `load_validation_functions` is imported by name at the top of the
    # notebook; the module name `load_functions` is never bound, so the
    # previous qualified call raised NameError.
    validation_functions = load_validation_functions("data/test_dataset", language_pair="cpp_java")
    # add_to_datastore / train_datastore come from the guarded import at the
    # top of the notebook and are unavailable if that import failed.
    add_to_datastore(knnmt, { "cpp_java": validation_functions }, is_validation=True)
    train_datastore(knnmt, language_pair="cpp_java")

In [7]:
# Language pair demonstrated in this notebook: "<source>_<target>".
language_pair = "cpp_java"
src_language, tgt_language = language_pair.split("_")

# Checkpoint names use title-cased languages, except that C++ is spelled
# "CPP" (e.g. models/Online_ST_CPP_Java.pth).
translator_path = f"models/Online_ST_{src_language.title()}_{tgt_language.title()}.pth"
translator_path = translator_path.replace("Cpp", "CPP")

# The translator is given the same datastore directory as the KNNMT instance.
translator = Translator(
    translator_path,
    "data/bpe/cpp-java-python/codes",
    global_model=True,
    knnmt_dir=knnmt.knnmt_dir,
)

# Walk through the token-by-token prediction for test sample 442.
output_sample(
    knnmt=knnmt,
    translator=translator,
    language_pair=language_pair,
    data_index=442,
)

INFO - 10/23/22 20:51:09 - 0:00:05 - Reloading encoder from models/Online_ST_CPP_Java.pth ...
INFO - 10/23/22 20:51:15 - 0:00:10 - Reloading decoders from models/Online_ST_CPP_Java.pth ...
INFO - 10/23/22 20:51:16 - 0:00:12 - Number of parameters (encoder): 143279616
INFO - 10/23/22 20:51:16 - 0:00:12 - Number of parameters (decoders): 168482304
INFO - 10/23/22 20:51:16 - 0:00:12 - Number of decoders: 1

Loading codes from /pfs/data5/home/hd/hd_hd/hd_tf268/code-gen/data/bpe/cpp-java-python/codes ...
Read 50000 codes from the codes file.


Loading Faiss Index for 'cpp_java'
Loading Datastore Values for 'cpp_java'
Values:  (61823,)
Loading Datastore Inputs for 'cpp_java'
Values:  (61823, 2)






SOURCE: 'bool is@@ Even ( int n ) { return ( ! ( n & 1 ) ) ; }'
GENERATED: ''

----------------------------------------------------------------------------------------------------

NEXT TC TARGET: 'public'
PREDICTIONS: ['static', 'static', 'static', 'public', 'static']
DISTANCES: [103 107 107 110 113]

----------------------------------------------------------------------------------------------------

INPUT SOURCE: 'bool is@@ Prime ( int n ) { if ( n <= 1 ) return false ; for ( int i = 2 ; i < n ; i ++ ) if ( n % i == 0 ) return false ; return true ; } </s>'
INPUT TARGET: ''






SOURCE: 'bool is@@ Even ( int n ) { return ( ! ( n & 1 ) ) ; }'
GENERATED: 'static'

----------------------------------------------------------------------------------------------------

NEXT TC TARGET: 'boolean'
PREDICTIONS: ['boolean', 'boolean', 'bo