# Prediction using the pre-trained Transformer

(C) Maxim Gansert, Mindscan Engineering, 2020


In [None]:
import sys
sys.path.insert(0,'../../src')

import tensorflow as tf

from com.github.c2nes.javalang import tokenizer as tokenizer

from de.mindscan.fluentgenesis.bpe.bpe_model import BPEModel
from de.mindscan.fluentgenesis.bpe.bpe_encoder_decoder import SimpleBPEEncoder

from de.mindscan.fluentgenesis.transformer import TfTransformerV1


In [None]:
# Set up the "16K-full" BPE model from the bpe package directory and load
# its hyperparameters, token vocabulary and BPE pair data.
bpe_model = BPEModel("16K-full", "../../src/de/mindscan/fluentgenesis/bpe/")
bpe_model.load_hparams()

bpe_model_vocabulary = bpe_model.load_tokens()
bpe_model_bpe_data = bpe_model.load_bpe_pairs()

# Encoder/decoder built from the loaded vocabulary and pair data; the
# functions in the later cells use it through encode() and decode().
bpe_encoder = SimpleBPEEncoder(bpe_model_vocabulary, bpe_model_bpe_data)

In [None]:
def create_padding_mask(seq):
    """Build a float padding mask for a batch of token-id sequences.

    Every position whose id equals 0 is treated as padding and becomes 1.0
    in the mask; every other position becomes 0.0.

    Args:
        seq: Two-dimensional tensor of token ids, (batch_size, seq_len).

    Returns:
        A tf.float32 tensor of shape (batch_size, 1, 1, seq_len).
    """
    is_padding = tf.math.equal(seq, 0)
    padding_mask = tf.cast(is_padding, tf.float32)

    # Insert two singleton axes between the batch axis and the sequence axis.
    return padding_mask[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(size):
    """Return a (size, size) mask that is 1.0 strictly above the diagonal.

    The lower triangle including the diagonal is 0.0, so position i only
    leaves positions 0..i unmasked.
    """
    all_ones = tf.ones((size, size))
    # band_part(x, -1, 0) keeps the full lower triangle and the diagonal.
    lower_triangle = tf.linalg.band_part(all_ones, -1, 0)
    return 1 - lower_triangle


def create_masks(inp, tar):
    """Create the three masks passed to the transformer call.

    Args:
        inp: Encoder input token ids, (batch_size, input_len).
        tar: Decoder target token ids, (batch_size, target_len).

    Returns:
        A tuple (enc_padding_mask, combined_mask, dec_padding_mask), where
        combined_mask is the element-wise maximum of the target padding
        mask and the look-ahead mask.
    """
    # Padding mask for the encoder.
    enc_padding_mask = create_padding_mask(inp)

    # Used in the second attention block of the decoder to mask the input;
    # it is derived from the encoder input as well.
    dec_padding_mask = create_padding_mask(inp)

    target_length = tf.shape(tar)[1]
    future_mask = create_look_ahead_mask(target_length)
    target_padding_mask = create_padding_mask(tar)

    # A target position is masked when it is padding or lies in the future.
    combined_mask = tf.maximum(target_padding_mask, future_mask)

    return enc_padding_mask, combined_mask, dec_padding_mask

In [None]:
# Token id placed at the start of both the encoder input and the decoder
# output in sample_transformer_nextline.
START_TOKEN = 16273
# Token id treated as padding (create_padding_mask masks ids equal to 0);
# sampling also stops when this id is predicted.
PAD_TOKEN = 0
# Upper bound on the number of decoding steps per sampled line.
MAX_OUTPUTLENGTH = 64

def tokenize_java_code(theSource: str):
    """Tokenize Java source text and return the token values as a list.

    The javalang tokenizer is called with ignore_errors=True.
    """
    java_tokens = tokenizer.tokenize(theSource, ignore_errors=True)
    return [java_token.value for java_token in list(java_tokens)]

def plot_prediction_attention_weights(attention, sentence, result, layer):
    """Plot the attention weights of every head of one attention block.

    Args:
        attention: Mapping (or other indexable) from a layer key to an
            attention tensor whose first axis has size 1; that axis is
            squeezed away and the next axis is iterated as heads.
        sentence: Input token ids, decoded one by one for the x tick labels
            (framed by '<start>' and '<end>').
        result: Predicted token ids, decoded one by one for the y tick
            labels; ids that are not below START_TOKEN are skipped.
        layer: Key selecting the entry of `attention` to plot.

    The subplot grid is fixed at 2 x 4, so at most eight heads can be drawn.
    """
    # The notebook never imports matplotlib although this function relies on
    # `plt`; importing it here keeps the function from failing with a
    # NameError when a plot is requested.
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(16, 8))

    attention = tf.squeeze(attention[layer], axis=0)

    # Labels and font settings are the same for every head: compute them once.
    fontdict = {'fontsize': 10}
    x_labels = ['<start>'] + [bpe_encoder.decode([i]) for i in sentence] + ['<end>']
    y_labels = [bpe_encoder.decode([i]) for i in result if i < START_TOKEN]

    for head in range(attention.shape[0]):
        ax = fig.add_subplot(2, 4, head + 1)

        # Plot the attention weights; the last row of the matrix is left out.
        ax.matshow(attention[head][:-1, :], cmap='viridis')

        ax.set_xticks(range(len(sentence) + 2))
        ax.set_yticks(range(len(result)))

        ax.set_ylim(len(result) - 1.5, -0.5)

        ax.set_xticklabels(x_labels, fontdict=fontdict, rotation=90)
        ax.set_yticklabels(y_labels, fontdict=fontdict)

        ax.set_xlabel('Head {}'.format(head + 1))

    plt.tight_layout()
    plt.show()

#
# the following function shall sample the first line of a method, from a transformer
# greedy decoder...
#
def sample_transformer_nextline(transformer, class_name, method_name, method_signature):
    """Greedily decode one line of code from the transformer.

    The encoder input is START_TOKEN, the BPE ids of
    [class_name, '.', method_name], the BPE ids of the tokenized
    method_signature, and two trailing PAD_TOKENs.

    Args:
        transformer: Callable invoked as transformer(encoder_input, output,
            False, enc_padding_mask, combined_mask, dec_padding_mask) and
            returning (predictions, attention_weights).
        class_name: Name of the class the method belongs to.
        method_name: Name of the method.
        method_signature: Java source text used as context.

    Returns:
        A tuple (output_ids, attention_weights, input_tokens): the decoded
        ids with the batch axis squeezed away, the attention weights of the
        last transformer call, and the encoder input ids as a Python list.
    """
    qualified_name_ids = bpe_encoder.encode([class_name, '.', method_name])
    signature_ids = bpe_encoder.encode(tokenize_java_code(method_signature))
    input_tokens = [START_TOKEN] + qualified_name_ids + signature_ids + [PAD_TOKEN, PAD_TOKEN]
    encoderinput = tf.expand_dims(input_tokens, 0)

    # The decoder output starts out holding only the start token.
    output = tf.expand_dims([START_TOKEN], 0)

    # For lack of a good end-of-sentence symbol, ";" or "}" act as the
    # trigger for "end of sentence". Not a good solution, but it works for
    # now, even with a messy dataset.
    end_of_line_ids = (60, 126)

    for _ in range(MAX_OUTPUTLENGTH):
        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(encoderinput, output)

        predictions, attention_weights = transformer(
            encoderinput,
            output,
            False,
            enc_padding_mask,
            combined_mask,
            dec_padding_mask,
        )

        # Greedy sampling: take the most likely id at the last position.
        last_step = predictions[:, -1, :]
        predicted_id = tf.argmax(last_step, axis=-1, output_type=tf.int32)

        # A predicted padding token ends decoding and is not appended.
        if predicted_id == PAD_TOKEN:
            break

        output = tf.concat([output, [predicted_id]], axis=-1)

        # An end-of-line token is appended first, then decoding stops.
        if predicted_id in end_of_line_ids:
            break

    return tf.squeeze(output, axis=0), attention_weights, input_tokens

In [None]:
def predict_first_line(transformer, class_name, method_name, method_signature, plot=''):
    """Sample one line for the given context and print input and prediction.

    Args:
        transformer: Model handed on to sample_transformer_nextline.
        class_name: Name of the class the method belongs to.
        method_name: Name of the method.
        method_signature: Java source text used as context.
        plot: When truthy, the key of the attention block whose weights are
            plotted after printing; nothing is plotted by default.
    """
    sampled_ids, attention_weights, context_ids = sample_transformer_nextline(
        transformer, class_name, method_name, method_signature)

    def is_plain_token(token_id):
        # Keep only ids strictly between 0 and 16273.
        return (token_id > 0) and (token_id < 16273)

    result = list(filter(is_plain_token, sampled_ids.numpy()))
    input_tokens = list(filter(is_plain_token, context_ids))

    decoded_input_tokens = bpe_encoder.decode(input_tokens)
    predicted_line = bpe_encoder.decode(result)

    print('Input Context: {}'.format(decoded_input_tokens))
    print('Predicted output: {}'.format(predicted_line))

    if plot:
        # Plot the attention weights of the requested block.
        plot_prediction_attention_weights(attention_weights, input_tokens, result, plot)


In [None]:
# Build the transformer whose weights are loaded in the next cell.
# The vocabulary sizes (16274) equal START_TOKEN + 1.
# NOTE(review): these hyperparameters presumably have to match the ones used
# when the checkpoint was trained, and rate=0.0 presumably disables dropout;
# confirm against TfTransformerV1 and the training notebook.
transformer_restored = TfTransformerV1.Transformer(
    num_layers=4, d_model=256, num_heads=8, dff=1024,
    input_vocab_size=16274, target_vocab_size=16274,
    pe_input=512, pe_target=512,
    rate=0.0
    )


In [None]:
# Load the saved weights of the "nextlineofcode_by_context" v3 checkpoint.
transformer_restored.load_weights(filepath='../../data/checkpoints/nextlineofcode_by_context/v3/tf')

In [None]:
# Example predictions: first with only the method signature as context,
# then with the signature followed by the first line of the method body.
predict_first_line(transformer_restored, 'Config', 'getInstance', 'Config getInstance()')
predict_first_line(transformer_restored, 'Config', 'getInstance', 'Config getInstance() if (Config.instance == null) {')

In [None]:
# Look up the BPE ids of ';' and '}'.
# NOTE(review): presumably a check of the hard-coded end-of-line ids
# (60, 126) used in sample_transformer_nextline; verify.
bpe_encoder.encode([';', '}'])