In [3]:
import torch
from transformers import RobertaTokenizer, T5ForConditionalGeneration
# CodeT5 checkpoint used for code summarization in the first half of this notebook.
CODET5_CHECKPOINT = 'Salesforce/codet5-base-multi-sum'
tokenizer = RobertaTokenizer.from_pretrained(CODET5_CHECKPOINT)
model = T5ForConditionalGeneration.from_pretrained(CODET5_CHECKPOINT)

# Use the GPU whenever torch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
# Binding the result to `_` keeps the module repr out of the cell output.
_ = model.to(device)

In [4]:
def predict(model, code, gold, max_length=64, tok=None):
    """Summarize `code` with a seq2seq model and score the summary against `gold`.

    Args:
        model: encoder-decoder model exposing `.generate` (CodeT5 in this notebook).
        code: source-code string to summarize.
        gold: reference summary used for the smoothed BLEU-4 score.
        max_length: `max_length` forwarded to `model.generate` (default 64, as before).
        tok: tokenizer to use. Defaults to the global `tokenizer`, looked up at
            call time. NOTE: the GraphCodeBERT cell later in this notebook rebinds
            `tokenizer`; pass the CodeT5 tokenizer explicitly if that cell has run.

    Returns:
        (comment, bleu): the decoded summary and `getSmoothBLEU4(gold, comment)`.
    """
    if tok is None:
        # Resolved at call time, so this picks up whatever `tokenizer` is bound now.
        tok = tokenizer
    input_ids = tok(code, return_tensors="pt").input_ids
    generated_ids = model.generate(input_ids.to(device), max_length=max_length)
    comment = tok.decode(generated_ids[0], skip_special_tokens=True)
    return comment, getSmoothBLEU4(gold, comment)

In [5]:
import nltk
from nltk.translate.bleu_score import SmoothingFunction
chencherry = SmoothingFunction()

def getSmoothBLEU4(gold, result):
    """Sentence-level BLEU-4 of `result` against the single reference `gold`.

    Both strings are split on whitespace only, so matching is case-sensitive
    and punctuation stays attached to words. The n-gram weights are uniform
    over orders 1-4, smoothing is NLTK's `SmoothingFunction().method2`, and
    the score is rounded to 4 decimal places.
    """
    reference_tokens = gold.split()
    hypothesis_tokens = result.split()
    uniform_weights = [0.25] * 4
    score = nltk.translate.bleu_score.sentence_bleu(
        [reference_tokens],
        hypothesis_tokens,
        weights=uniform_weights,
        smoothing_function=chencherry.method2,
    )
    return round(score, 4)

In [5]:
# Raw Java method (original formatting and inline comment kept), summarized with CodeT5 below.
code = '''
public RequestMethodsRequestCondition getMatchingCondition(ServerWebExchange exchange) {
                if (CorsUtils.isPreFlightRequest(exchange.getRequest())) {
                        return matchPreFlight(exchange.getRequest());
                }
                if (getMethods().isEmpty()) {
                        if (RequestMethod.OPTIONS.name().equals(exchange.getRequest().getMethodValue())) {
                                return null; // We handle OPTIONS transparently, so don't match if no explicit declarations
                        }
                        return this;
                }
                return matchRequestMethod(exchange.getRequest().getMethod());
        }

'''

In [6]:
# Reference summary for `code`; identical to `sample['docstring_tokens']` joined with spaces.
gold = 'check if any of the http request methods match the given request and return an instance that contains the matching http request method only'

In [9]:
# CodeT5 summary of the raw method and its smoothed BLEU-4 against `gold`.
predict(model, code, gold)

('Get matching condition.', 0.0004)

In [2]:
# Build a GraphCodeBERT encoder + torch TransformerDecoder wrapped in `Seq2Seq`.
# NOTE(review): this cell rebinds the global `tokenizer` (and `device`) created in the
# CodeT5 cell above. `predict` looks up `tokenizer` at call time, so calling it after
# this cell would tokenize CodeT5 input with the GraphCodeBERT vocabulary.
import torch
pretrained_model_name = "microsoft/graphcodebert-base"
from transformers import RobertaTokenizer, RobertaConfig, RobertaModel
# `Seq2Seq` comes from a local `model` module (not shown here).
from model import Seq2Seq
MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}
config_class, model_class, tokenizer_class = MODEL_CLASSES['roberta']
config = config_class.from_pretrained(pretrained_model_name)
tokenizer = tokenizer_class.from_pretrained(pretrained_model_name)
# build model
encoder = model_class.from_pretrained(pretrained_model_name, config=config)
# Decoder width and head count follow the encoder config; 6 decoder layers.
decoder_layer = torch.nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
decoder = torch.nn.TransformerDecoder(decoder_layer, num_layers=6)
# beam_size / max_length presumably control decoding inside Seq2Seq (verify in model.py);
# the CLS and SEP token ids are passed as start- and end-of-sequence ids.
pretrained_model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,
                beam_size=10, max_length=64,
                sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Some weights of the model checkpoint at microsoft/graphcodebert-base were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.decoder.bias', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at microsoft/graphcodebert-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to

In [3]:
import os
def load_model(load_model_path='fine_tuned_models'):
    """Load fine-tuned weights into the global `pretrained_model` and move it to `device`.

    Args:
        load_model_path: directory that contains `pytorch_model.bin`. When it is
            falsy (None or ''), no weights are loaded, but the model is still
            moved to `device`.
    """
    if load_model_path:
        checkpoint_file = os.path.join(load_model_path, 'pytorch_model.bin')
        # Deserialize onto the CPU first so GPU-saved checkpoints load anywhere.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(checkpoint_file, map_location='cpu')
        pretrained_model.load_state_dict(state_dict)
    pretrained_model.to(device)
# Load weights from ./fine_tuned_models (relative to the working directory) and move to `device`.
load_model()

In [10]:
def prediction(example, max_source_length=256):
    """Summarize one code snippet with the fine-tuned GraphCodeBERT `pretrained_model`.

    Args:
        example: code string, tokenized with the global `tokenizer`.
        max_source_length: total encoder input length *including* the CLS and SEP
            tokens; longer inputs are truncated, shorter ones are padded.

    Returns:
        (text, score): the decoded summary for the (single) input and the `score`
        object returned by `pretrained_model`.
    """
    pretrained_model.eval()
    # Reserve two slots for CLS/SEP; slicing to max_source_length first used to
    # produce up to max_source_length + 2 ids with no padding at all.
    source_tokens = tokenizer.tokenize(example)[:max(max_source_length - 2, 0)]
    source_tokens = [tokenizer.cls_token] + source_tokens + [tokenizer.sep_token]
    source_ids = tokenizer.convert_tokens_to_ids(source_tokens)
    source_mask = [1] * len(source_ids)
    padding_length = max_source_length - len(source_ids)
    source_ids += [tokenizer.pad_token_id] * padding_length
    source_mask += [0] * padding_length
    # Inference only: don't build an autograd graph through decoding.
    with torch.no_grad():
        preds, score = pretrained_model(
            source_ids=torch.tensor([source_ids]).to(device),
            source_mask=torch.tensor([source_mask]).to(device),
        )
    result = []
    for pred in preds:
        # pred[0] is presumably the top-ranked beam for this input — TODO confirm in model.py.
        token_ids = pred[0].cpu().tolist()
        # Everything from the first id 0 onward is dropped (treated as trailing fill).
        if 0 in token_ids:
            token_ids = token_ids[:token_ids.index(0)]
        # Appended inside the loop; it previously sat after the loop and kept only the last decode.
        result.append(tokenizer.decode(token_ids, clean_up_tokenization_spaces=False))
    return result[0], score

In [11]:
# Same method in pre-tokenized form: lower-cased, camelCase identifiers split into words,
# tokens separated by single spaces. Note that this rebinds `code`.
code = '''
public request methods request condition get matching condition ( server web exchange exchange ) { if ( cors utils is pre flight request ( exchange get request ( ) ) ) { return match pre flight ( exchange get request ( ) ) ; } if ( get methods ( ) is empty ( ) ) { if ( request method options name ( ) equals ( exchange get request ( ) get method value ( ) ) ) { return null ; / / we handle options transparently , so don ' t match if no explicit declarations } return this ; } return match request method ( exchange get request ( ) get method ( ) ) ; }
'''

In [13]:
# One dataset record: `docstring_tokens` is the reference summary, `code_tokens` the tokenized method.
sample = {"docstring_tokens": ["check", "if", "any", "of", "the", "http", "request", "methods", "match", "the", "given", "request", "and", "return", "an", "instance", "that", "contains", "the", "matching", "http", "request", "method", "only"], "code_tokens": ["public", "request", "methods", "request", "condition", "get", "matching", "condition", "(", "server", "web", "exchange", "exchange", ")", "{", "if", "(", "cors", "utils", "is", "pre", "flight", "request", "(", "exchange", "get", "request", "(", ")", ")", ")", "{", "return", "match", "pre", "flight", "(", "exchange", "get", "request", "(", ")", ")", ";", "}", "if", "(", "get", "methods", "(", ")", "is", "empty", "(", ")", ")", "{", "if", "(", "request", "method", "options", "name", "(", ")", "equals", "(", "exchange", "get", "request", "(", ")", "get", "method", "value", "(", ")", ")", ")", "{", "return", "null", ";", "/", "/", "we", "handle", "options", "transparently", ",", "so", "don", "'", "t", "match", "if", "no", "explicit", "declarations", "}", "return", "this", ";", "}", "return", "match", "request", "method", "(", "exchange", "get", "request", "(", ")", "get", "method", "(", ")", ")", ";", "}"]}

In [12]:
# Summarize the pre-tokenized method (with its surrounding newlines) using the fine-tuned model.
prediction(code)

  prevK = bestScoresId // numWords


('checks if any of the http request methods match the given request and returns a matching { @ link request methods request condition } if no match is defined , return { @ code null } otherwise',
 tensor(-14.6903, device='cuda:0'))

In [16]:
# Rebuild the model input from `code_tokens`; it should match `code` without its leading/trailing newlines.
' '.join(sample['code_tokens'])

"public request methods request condition get matching condition ( server web exchange exchange ) { if ( cors utils is pre flight request ( exchange get request ( ) ) ) { return match pre flight ( exchange get request ( ) ) ; } if ( get methods ( ) is empty ( ) ) { if ( request method options name ( ) equals ( exchange get request ( ) get method value ( ) ) ) { return null ; / / we handle options transparently , so don ' t match if no explicit declarations } return this ; } return match request method ( exchange get request ( ) get method ( ) ) ; }"

In [15]:
# Same text as `prediction(code)` minus the surrounding newlines; compare the two summaries.
prediction(' '.join(sample['code_tokens']))

('checks if any of the configured http request methods match the given request and returns a { @ link request methods request condition } if no match is found , return { @ code null } otherwise',
 tensor(-15.1087, device='cuda:0'))