# Interactive Prediction using shared embeddings

(C) Maxim Gansert, Mindscan, 2020

In [None]:
import sys
sys.path.insert(0,'../../src')

## Outline of our Target

* load the embeddings
* instantiate and load a pretrained transformer model
* provide multiple interactive boxes
  * classname
  * method name
  * method signature
  * context
  * current line
  
* implement an interactive predictor, which can be queried to provide the next tokens
* also implement a filter mechanism that restricts the BPE tokens to those matching the (partially typed) input, so that only these are sampled — a kind of subword-prefix constraint, possibly implemented as a tree search?

## Performance improvements can be gained

* by caching the masked multihead attention calculations for each layer, for each input (self attention)
* by caching the multi-head attention calculation that uses the encoder's K,V

## Further performance improvements

* the model might also be optimized with a pruning and quantization model optimizer, so that it runs faster when many weights are zero

# Support Code

In [None]:
import ipywidgets as widgets

In [None]:
from com.github.c2nes.javalang import tokenizer as tokenizer

### Load the BPE Encodings and the BPE-Support

First, we need to load our BPE model, which is used to encode all the Java tokens.

In [None]:
from de.mindscan.fluentgenesis.bpe.bpe_model import BPEModel
from de.mindscan.fluentgenesis.bpe.bpe_encoder_decoder import SimpleBPEEncoder

In [None]:
# Special token ids. They are registered in the BPE vocabulary further below
# and fed to the transformer. PAD must stay 0 because create_padding_mask
# treats every 0 id as padding.
SYMBOL_PAD, SYMBOL_START, SYMBOL_EOS = 0, 16273, 16274

In [None]:
# Pretrained BPE model: name and directory of its data files.
_BPE_MODEL_NAME = "16K-full"
_BPE_MODEL_DIR = "../../src/de/mindscan/fluentgenesis/bpe/"

bpemodel = BPEModel(_BPE_MODEL_NAME, _BPE_MODEL_DIR)
# hyperparameters first, then the vocabulary (token string -> id) and the
# BPE pair data consumed by SimpleBPEEncoder
bpemodel.load_hparams()
bpemodel_vocabulary = bpemodel.load_tokens()
bpemodel_bpe_data = bpemodel.load_bpe_pairs()

Extend the vocabulary with the special symbols that were used during encoding.

In [None]:
# Register the special symbols in the vocabulary that the encoder/decoder
# is built from.
for _symbol_name, _symbol_id in (
    ('<PAD>', SYMBOL_PAD),      # padding
    ('<START>', SYMBOL_START),  # start symbol
    ('<EOS>', SYMBOL_EOS),      # end of sentence
):
    bpemodel_vocabulary[_symbol_name] = _symbol_id
del _symbol_name, _symbol_id


In [None]:
# Encoder/decoder over the extended vocabulary and the BPE pair data.
bpe_encoder = SimpleBPEEncoder(
    bpemodel_vocabulary,
    bpemodel_bpe_data,
)

# Number of entries in the extended vocabulary. Not referenced in the code
# visible here; the transformer below hard-codes its vocabulary size.
MODEL_VOCABULARY_LENGTH = len(bpemodel_vocabulary)

### Load the Transformer model using the checkpoints

In [None]:
import time
import numpy as np
import tensorflow as tf

from de.mindscan.fluentgenesis.transformer import TfTransformerV2

# Upper bound on the number of tokens sample_transformer_nextline appends.
MAX_OUTPUTLENGTH: int = 120

In [None]:
# Architecture of the checkpointed model. These values presumably have to
# match the ones used for training so that load_weights can restore the
# checkpoint. TODO confirm against the training notebook.
_transformer_hparams = dict(
    num_layers=4,
    d_model=256,
    num_heads=8,
    dff=1024,
    # 16275 == SYMBOL_EOS + 1, i.e. every id up to the EOS symbol fits
    input_vocab_size=16275,
    target_vocab_size=16275,
    pe_input=512,
    pe_target=512,
    # presumably the dropout rate (0.0 for inference) - verify in TfTransformerV2
    rate=0.0,
)
transformer = TfTransformerV2.Transformer(**_transformer_hparams)

transformer.load_weights(filepath='../../data/checkpoints/nextlineofcode_s_emb/v5/tf')

In [None]:
def create_padding_mask(seq):
    """Build a padding mask from a batch of token-id sequences.

    Every position whose id is 0 (padding) becomes 1.0; every other
    position becomes 0.0.

    Args:
        seq: token ids, presumably shaped (batch_size, seq_len).

    Returns:
        float32 tensor of shape (batch_size, 1, 1, seq_len).
    """
    padding_positions = tf.math.equal(seq, 0)
    mask = tf.cast(padding_positions, tf.float32)

    # insert two singleton axes: (batch, seq_len) -> (batch, 1, 1, seq_len)
    return mask[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(size):
    """Return a (size, size) float mask that marks future positions.

    Entry [i, j] is 1.0 when j > i (strictly above the diagonal) and 0.0
    otherwise. As in create_padding_mask, 1.0 marks a position to hide.
    """
    # band_part(..., -1, 0) keeps the lower triangle including the diagonal
    lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return 1 - lower_triangle


def create_masks(inp, tar):
    """Build the three attention masks for one encoder/decoder pass.

    Args:
        inp: encoder input token ids, presumably (batch, inp_len).
        tar: decoder input token ids; axis 1 is read as its length.

    Returns:
        (enc_padding_mask, combined_mask, dec_padding_mask):
        * enc_padding_mask: padding mask of ``inp`` for the encoder.
        * combined_mask: element-wise maximum of the padding mask of ``tar``
          and the look-ahead mask (decoder self-attention).
        * dec_padding_mask: padding mask of ``inp``, used in the second
          attention block of the decoder to mask the encoder input.
    """
    # both masks cover the same encoder-side padding positions
    enc_padding_mask = create_padding_mask(inp)
    dec_padding_mask = create_padding_mask(inp)

    target_length = tf.shape(tar)[1]
    look_ahead = create_look_ahead_mask(target_length)
    target_padding = create_padding_mask(tar)
    combined_mask = tf.maximum(target_padding, look_ahead)

    return enc_padding_mask, combined_mask, dec_padding_mask

In [None]:

def sample_transformer_nextline(transformer, input_tokens, output_tokens,
                                *, max_length=None, greedy=False):
    """Autoregressively extend ``output_tokens`` using the transformer.

    The full encoder/decoder is re-run once per generated token (no caching).
    One token id is appended per step until PAD or EOS is produced or
    ``max_length`` ids were added.

    Args:
        transformer: model called as
            ``transformer(enc_in, dec_in, False, enc_mask, combined_mask, dec_mask)``
            and returning ``(predictions, attention_weights)``.
        input_tokens: 1-D sequence of encoder token ids (no batch axis).
        output_tokens: 1-D, non-empty sequence of decoder ids to continue
            (callers pass at least SYMBOL_START).
        max_length: maximum number of ids to generate; defaults to
            MAX_OUTPUTLENGTH.
        greedy: if True, take the arg-max id; otherwise sample from the
            categorical distribution given by the logits (default).

    Returns:
        ``(output, attention_weights, input_tokens)``:
        * ``output``: 1-D tensor of the prefix plus generated ids, without
          the terminating PAD/EOS.
        * ``attention_weights``: the weights of the last forward pass, or
          None if no pass ran.
        * ``input_tokens``: returned unchanged.
    """
    if max_length is None:
        max_length = MAX_OUTPUTLENGTH

    encoder_input = tf.expand_dims(input_tokens, 0)
    output = tf.expand_dims(output_tokens, 0)
    attention_weights = None  # stays defined even if the loop never runs

    for _ in range(max_length):
        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(encoder_input, output)

        predictions, attention_weights = transformer(
            encoder_input,
            output,
            False,
            enc_padding_mask,
            combined_mask,
            dec_padding_mask,
        )
        # logits of the last decoder position, presumably (1, vocab_size)
        last_logits = predictions[:, -1, :]

        if greedy:
            predicted_id = tf.argmax(last_logits, axis=-1)
        else:
            predicted_id = tf.random.categorical(logits=last_logits, num_samples=1)[0]
        # argmax/categorical return int64 by default, while `output` built from
        # a Python int list is int32; tf.concat requires matching dtypes.
        predicted_id = tf.cast(predicted_id, output.dtype)

        # compare a plain Python int rather than relying on tensor truthiness
        if int(predicted_id[0]) in (SYMBOL_PAD, SYMBOL_EOS):
            break

        output = tf.concat([output, tf.reshape(predicted_id, (1, 1))], axis=-1)

    return tf.squeeze(output, axis=0), attention_weights, input_tokens

### Prediction Code

In [None]:
def tokenize_java_code(theSource: str):
    """Return the string values of the Java lexer tokens in ``theSource``.

    The tokenizer is called with ``ignore_errors=True``, presumably so that
    incomplete code typed interactively can still be tokenized.
    """
    return [token.value for token in tokenizer.tokenize(theSource, ignore_errors=True)]

def predict_line(transformer, class_name, method_name, method_signature, context, current_line):
    """Predict the continuation of ``current_line`` and return it decoded.

    Encoder input:  START, ``class_name . method_name``, tokens of
                    ``method_signature``, tokens of ``context``, EOS.
    Decoder prefix: START followed by the tokens of ``current_line``.

    Returns:
        Whatever ``bpe_encoder.decode`` yields for the sampled decoder ids.
        That sequence still begins with the START id.
    """
    encode = bpe_encoder.encode

    encoder_ids = [SYMBOL_START]
    encoder_ids += encode([class_name, '.', method_name])
    encoder_ids += encode(tokenize_java_code(method_signature))
    # the previous lines serve as context
    encoder_ids += encode(tokenize_java_code(context))
    encoder_ids.append(SYMBOL_EOS)

    # NOTE: the last Java token of current_line may still be incomplete (the
    # user is typing). It is encoded as if complete. Treating it as a prefix
    # "preference" is not implemented yet.
    decoder_ids = [SYMBOL_START]
    decoder_ids += encode(tokenize_java_code(current_line))

    sampled_ids, _, _ = sample_transformer_nextline(transformer, encoder_ids, decoder_ids)
    return bpe_encoder.decode(sampled_ids.numpy())


# Interactive Part

In [None]:
# Shared layouts: hLayout only sets the width; vLayout additionally fixes
# the height for multi-line areas.
hLayout = widgets.Layout(width='80%')
vLayout = widgets.Layout(height='150px', width='80%')

# Area where the predicted continuation is displayed.
outputTextArea = widgets.Textarea(layout=vLayout, description='prediction(s):')

In [None]:

# Input widgets for the information the predictor is conditioned on.
# Each field gets its own description so the user can tell them apart.
classnameInputTextField = widgets.Text(
    value='',
    placeholder='class name goes here',
    description='Class:',
    disabled=False,
    layout=hLayout,
)

methodnameInputTextField = widgets.Text(
    value='',
    placeholder='method name goes here',
    description='Method:',
    disabled=False,
    layout=hLayout,
)

methodsignatureInputTextField = widgets.Text(
    value='',
    placeholder='method signature goes here',
    description='Signature:',
    disabled=False,
    layout=hLayout,
)

# Previous lines of the method body; committed lines are appended here.
methodcontextInputTextArea = widgets.Textarea(
    value='',
    placeholder='previous lines of the method body',
    description='Context:',
    layout=vLayout,
)

# A Textarea (not Text) on purpose: the change handler detects a typed
# newline to commit the current line into the context.
currentLineInputTextField = widgets.Textarea(
    value='',
    placeholder='current line',
    description='CurrentLine:',
    disabled=False,
    layout=hLayout,
)

In [None]:
def currentLineHandler(obj):
    """React to edits of the current-line box.

    * If the new text ends with a newline, the line is committed. It is
      appended to the context box on its own line, and the current-line box
      is cleared. Clearing the box fires this observer again with an empty
      line, which produces a prediction for the next line.
    * Otherwise the transformer is queried with all inputs, and the decoded
      prediction is shown in ``outputTextArea``.

    Args:
        obj: the change notification delivered by ``observe``; only
            ``obj.new`` (the new text) is read.
    """
    updated_line = obj.new

    if updated_line.endswith('\n'):
        context = methodcontextInputTextArea.value
        # Keep the committed line separate from context that lacks a trailing
        # newline; otherwise the last tokens of both would fuse together.
        if context and not context.endswith('\n'):
            context += '\n'
        methodcontextInputTextArea.value = context + updated_line
        currentLineInputTextField.value = ''
        return

    predicted_tokens = predict_line(
        transformer,
        classnameInputTextField.value,
        methodnameInputTextField.value,
        methodsignatureInputTextField.value,
        methodcontextInputTextArea.value,
        updated_line,
    )
    # NOTE(review): assumes decode returns a sequence of token strings; if it
    # returns a single str, this inserts a space between every character.
    outputTextArea.value = ' '.join(predicted_tokens)


currentLineInputTextField.observe(currentLineHandler, names='value')

In [None]:
# Render the input widgets in order, followed by the prediction output.
for _widget in (
    classnameInputTextField,
    methodnameInputTextField,
    methodsignatureInputTextField,
    methodcontextInputTextArea,
    currentLineInputTextField,
    outputTextArea,
):
    display(_widget)