In [78]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [79]:
# modify these paths to your system's
MODEL_METHOD_MUTATION_PATH = "/home/chenghin/Desktop/repos/java-mutation-framework/models/code-generation/saved_models/checkpoint-best-score"
MODEL_COMMENT_MUTATION_PATH = "/home/chenghin/Desktop/repos/java-mutation-framework/models/codet5_base_all_lr5_bs32_src64_trg64_pat5_e10/checkpoint-best-bleu"

In [80]:
from model import Seq2Seq # Copy pasted from https://github.com/microsoft/CodeBERT/blob/master/UniXcoder/downstream-tasks/code-generation/model.py
from transformers import RobertaTokenizer, RobertaConfig, RobertaModel, T5ForConditionalGeneration, AutoTokenizer

PRETRAINED_MODEL_NAME_UNIXCODER = "microsoft/unixcoder-base"
PRETRAINED_MODEL_NAME_CODET5 = "Salesforce/codet5-base-multi-sum"

class TokenizerModelPair:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

def build_model(pretrained_model_name):
    if pretrained_model_name == PRETRAINED_MODEL_NAME_UNIXCODER:
        # build model
        tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name)
        config = RobertaConfig.from_pretrained(pretrained_model_name)
        # important！You must set is_decoder to True for generation
        config.is_decoder = True
        encoder = RobertaModel.from_pretrained(pretrained_model_name,config=config)
        model = Seq2Seq(encoder=encoder,decoder=encoder,config=config,
                        beam_size=10,max_length=256,
                        sos_id=tokenizer.convert_tokens_to_ids(["<mask0>"])[0],eos_id=tokenizer.sep_token_id)
    else:
        # build model
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
        model = T5ForConditionalGeneration.from_pretrained(pretrained_model_name)
        model.to(device)
    return TokenizerModelPair(model, tokenizer)

In [81]:
unixcoder_model_and_tokenizer = build_model(PRETRAINED_MODEL_NAME_UNIXCODER)
codet5_model_and_tokenizer = build_model(PRETRAINED_MODEL_NAME_CODET5)

Updated!!!


In [82]:
import os
def load_model(model, load_model_path='fine_tuned_models'):
    model_to_load = model.module if hasattr(model, 'module') else model
    load_model_path = os.path.join(load_model_path, 'pytorch_model.bin')
    model.load_state_dict(torch.load(load_model_path, map_location='cpu'))
    model.to(device)

In [83]:
load_model(unixcoder_model_and_tokenizer.model, MODEL_METHOD_MUTATION_PATH)
load_model(codet5_model_and_tokenizer.model, MODEL_COMMENT_MUTATION_PATH)

In [91]:
def get_tokens

def predict(model, tokenizer, code, gold):
    input_ids = tokenizer(code, return_tensors="pt").input_ids
    generated_ids = model.generate(input_ids.to(device))
    print(generated_ids)
    comment = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return comment, getSmoothBLEU4(gold, comment)

In [85]:
import nltk
from nltk.translate.bleu_score import SmoothingFunction
chencherry = SmoothingFunction()
def getSmoothBLEU4(gold, result):
    BLEUscore = nltk.translate.bleu_score.sentence_bleu([gold.split()], result.split(), weights = [0.25,0.25,0.25,0.25], smoothing_function=chencherry.method2)
    return round(BLEUscore,4)

In [86]:
code = '''
public RequestMethodsRequestCondition getMatchingCondition(ServerWebExchange exchange) {
                if (CorsUtils.isPreFlightRequest(exchange.getRequest())) {
                        return matchPreFlight(exchange.getRequest());
                }
                if (getMethods().isEmpty()) {
                        if (RequestMethod.OPTIONS.name().equals(exchange.getRequest().getMethodValue())) {
                                return null; // We handle OPTIONS transparently, so don't match if no explicit declarations
                        }
                        return this;
                }
                return matchRequestMethod(exchange.getRequest().getMethod());
        }

'''

In [87]:
gold = 'check if any of the http request methods match the given request and return an instance that contains the matching http request method only'

In [92]:
predict(unixcoder_model_and_tokenizer.model, unixcoder_model_and_tokenizer.tokenizer, code, gold) # test method mutation

tensor([[[ 653, 5128, 5763,  ...,    0,    0,    0],
         [ 653, 5128, 5763,  ...,    0,    0,    0],
         [1164, 5763, 1164,  ...,    0,    0,    0],
         ...,
         [ 653, 5128, 5763,  ...,    0,    0,    0],
         [ 653, 5128, 5763,  ...,    0,    0,    0],
         [1164, 5763, 1164,  ...,    0,    0,    0]]], device='cuda:0')


TypeError: int() argument must be a string, a bytes-like object or a real number, not 'list'

In [93]:
predict(codet5_model_and_tokenizer.model, codet5_model_and_tokenizer.tokenizer, gold, code) # test comment mutation

tensor([[    0,     1, 32099,   430,  1281,   434,   326,  1062,   590,  2590,
           845,   326,   864,   590,   471,   327,   392,   791,   716,  1914]],
       device='cuda:0')


('if any of the http request methods match the given request and return an instance that contains',
 0.0236)

In [None]:
class ExplainableMutator:
    def __init__(self, comment_mutation_model, method_mutation_model):
        self.comment_mutation_model = comment_mutation_model
        self.method_mutation_model = method_mutation_model

    def generate(self, comment, method_body):
        mutated_comment = predict(self.comment_mutation_model, method_body, comment)[0]
        mutated_method = predict(self.method_mutation_model, method_body, mutated_comment)[0]
        return [mutated_comment, mutated_method]

In [None]:
class ExplainableMutationSocketServer(SocketServer):
    def __init__(self, host, port, explainable_mutator):
        super().__init__(host, port)
        self.explainable_mutator = explainable_mutator

    def func(self):
        while True:
            print("-"*20)
            comment = self.recvMsg()
            method_body = self.recvMsg()
            [mutated_comment, mutated_method] = self.explainable_mutator.generate(comment, method_body)
            self.sendMsg(mutated_comment)
            self.sendMsg(mutated_method)

In [None]:
HOST = "127.0.0.1"
PORT = 8080
server = ExplainableMutationSocketServer(HOST, PORT, ExplainableMutator())

In [74]:
    tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base-multi-sum')
    model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base-multi-sum')

    text = """def svg_to_image(string, size=None):
    if isinstance(string, unicode):
        string = string.encode('utf-8')
        renderer = QtSvg.QSvgRenderer(QtCore.QByteArray(string))
    if not renderer.isValid():
        raise ValueError('Invalid SVG data.')
    if size is None:
        size = renderer.defaultSize()
        image = QtGui.QImage(size, QtGui.QImage.Format_ARGB32)
        painter = QtGui.QPainter(image)
        renderer.render(painter)
    return image"""

    input_ids = tokenizer(code, return_tensors="pt").input_ids

    generated_ids = model.generate(input_ids, max_length=20)
    print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

Get matching condition.
