In [1]:
import os
from transformers import RobertaTokenizer
from tqdm.notebook import tqdm

In [2]:
os.getcwd()

'/data'

# Data Loading

In [3]:
class Example(object):
    """A single training/test example."""
    def __init__(self,
                 source,
                 target,
                 lang
                 ):
        self.source = source
        self.target = target
        self.lang=lang

In [4]:
def read_examples(filename):
    """Read examples from filename."""
    examples=[]
    source,target=filename.split(',')
    lang='java'
    if source[-1]=='s':
        lang='c_sharp'
    with open(source,encoding="utf-8") as f1,open(target,encoding="utf-8") as f2:
        for idx, (line1,line2) in enumerate(zip(f1,f2)):
            line1=line1.strip()
            line2=line2.strip()
            examples.append(
                Example(
                    source=line1,
                    target=line2,
                    lang=lang
                        ) 
            )
            if idx == 5:
                break

    return examples

In [129]:
path_to_samples = "GraphCodeBERT/translation/data/train.java-cs.txt.java,GraphCodeBERT/translation/data/train.java-cs.txt.cs"
examples = read_examples(path_to_samples)[:1]

# Data Processing

In [6]:
from GraphCodeBERT.translation.parser import DFG_python,DFG_java,DFG_ruby,DFG_go,DFG_php,DFG_javascript,DFG_csharp
from GraphCodeBERT.translation.parser import (remove_comments_and_docstrings,
                   tree_to_token_index,
                   index_to_code_token,
                   tree_to_variable_index)
from tree_sitter import Language, Parser

In [101]:
parsers={}        
dfg_function={
    'python':DFG_python,
    'java':DFG_java,
    'ruby':DFG_ruby,
    'go':DFG_go,
    'php':DFG_php,
    'javascript':DFG_javascript,
    'c_sharp':DFG_csharp,
}

for lang in dfg_function:
    LANGUAGE = Language('GraphCodeBERT/translation/parser/my-languages.so', lang)
    parser = Parser()
    parser.set_language(LANGUAGE) 
    parser = [parser,dfg_function[lang]]    
    parsers[lang]= parser

In [8]:
tokenizer = RobertaTokenizer.from_pretrained("microsoft/graphcodebert-base")

In [123]:
#remove comments, tokenize code and extract dataflow     
def extract_dataflow(code, parser,lang):
    #remove comments
    try:
        code=remove_comments_and_docstrings(code,lang)
    except:
        pass    
    #obtain dataflow
    if lang=="php":
        code="<?php"+code+"?>"    
    try:
        tree = parser[0].parse(bytes(code,'utf8'))    
        root_node = tree.root_node  
        tokens_index=tree_to_token_index(root_node)     
        code=code.split('\n')
        code_tokens=[index_to_code_token(x,code) for x in tokens_index]  
        index_to_code={}
        for idx,(index,code) in enumerate(zip(tokens_index,code_tokens)):
            index_to_code[index]=(idx,code)
        # index_to_code = (  
        # ((0, 0), (0, 6)): (0, 'public'), 
        # ((0, 7), (0, 37)): (1, 'ListSpeechSynthesisTasksResult'), 
        # ((0, 38), (0, 62)): (2, 'listSpeechSynthesisTasks'), ...
        # )
        try:
            DFG,_=parser[1](root_node,index_to_code,{})
        except:
            DFG=[]
        DFG=sorted(DFG,key=lambda x:x[1])
        indexs=set()
        for d in DFG:
            if len(d[-1])!=0:
                indexs.add(d[1])
            for x in d[-1]:
                indexs.add(x)
        new_DFG=[]
        for d in DFG:
            if d[1] in indexs:
                new_DFG.append(d)
        print(DFG)
        dfg=new_DFG
    except:
        dfg=[]
    return code_tokens,dfg

In [119]:
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

class InputFeatures(object):
    """A single training/test features for a example."""
    def __init__(self,
                 example_id,
                 source_ids,
                 position_idx,
                 dfg_to_code,
                 dfg_to_dfg,                 
                 target_ids,
                 source_mask,
                 target_mask,

    ):
        self.example_id = example_id
        self.source_ids = source_ids
        self.position_idx = position_idx
        self.dfg_to_code = dfg_to_code
        self.dfg_to_dfg = dfg_to_dfg
        self.target_ids = target_ids
        self.source_mask = source_mask
        self.target_mask = target_mask  

def convert_examples_to_features(examples, tokenizer, args,stage=None):
    features = []
    for example_index, example in enumerate(tqdm(examples,total=len(examples))):
        ##extract data flow
        code_tokens,dfg=extract_dataflow(example.source,
                                         parsers["c_sharp" if args.source_lang == "cs" else "java"],
                                         "c_sharp" if args.source_lang == "cs" else "java")
        print(code_tokens)
        code_tokens=[tokenizer.tokenize('@ '+x)[1:] if idx!=0 else tokenizer.tokenize(x) for idx,x in enumerate(code_tokens)]
        ori2cur_pos={}
        ori2cur_pos[-1]=(0,0)
        for i in range(len(code_tokens)):
            ori2cur_pos[i]=(ori2cur_pos[i-1][1],ori2cur_pos[i-1][1]+len(code_tokens[i]))    
        code_tokens=[y for x in code_tokens for y in x]  
        
        #truncating
        code_tokens=code_tokens[:args.max_source_length-3][:512-3]
        source_tokens =[tokenizer.cls_token]+code_tokens+[tokenizer.sep_token]
        source_ids =  tokenizer.convert_tokens_to_ids(source_tokens)
        position_idx = [i+tokenizer.pad_token_id + 1 for i in range(len(source_tokens))]
        dfg=dfg[:args.max_source_length-len(source_tokens)]
        source_tokens+=[x[0] for x in dfg]
        position_idx+=[0 for x in dfg]
        source_ids+=[tokenizer.unk_token_id for x in dfg]
        padding_length=args.max_source_length-len(source_ids)
        position_idx+=[tokenizer.pad_token_id]*padding_length
        source_ids+=[tokenizer.pad_token_id]*padding_length      
        source_mask = [1] * (len(source_tokens))
        source_mask+=[0]*padding_length        
        
        #reindex
        reverse_index={}
        for idx,x in enumerate(dfg):
            reverse_index[x[1]]=idx
        for idx,x in enumerate(dfg):
            dfg[idx]=x[:-1]+([reverse_index[i] for i in x[-1] if i in reverse_index],)    
        dfg_to_dfg=[x[-1] for x in dfg]
        dfg_to_code=[ori2cur_pos[x[1]] for x in dfg]
        length=len([tokenizer.cls_token])
        dfg_to_code=[(x[0]+length,x[1]+length) for x in dfg_to_code]        
        

        #target
        if stage=="test":
            target_tokens = tokenizer.tokenize("None")
        else:
            target_tokens = tokenizer.tokenize(example.target)[:args.max_target_length-2]
        target_tokens = [tokenizer.cls_token]+target_tokens+[tokenizer.sep_token]            
        target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
        target_mask = [1] *len(target_ids)
        padding_length = args.max_target_length - len(target_ids)
        target_ids+=[tokenizer.pad_token_id]*padding_length
        target_mask+=[0]*padding_length   
       
        features.append(
            InputFeatures(
                 example_index,
                 source_ids,
                 position_idx,
                 dfg_to_code,
                 dfg_to_dfg,
                 target_ids,
                 source_mask,
                 target_mask,
            )
        )
    return features


In [11]:
args = dotdict({
    'source_lang': 'java',
    'max_source_length': 200,
    'max_target_length': 200
})

In [12]:
args.source_lang

'java'

In [131]:
examples[0].source

'public ListSpeechSynthesisTasksResult listSpeechSynthesisTasks(ListSpeechSynthesisTasksRequest request) {request = beforeClientExecution(request);return executeListSpeechSynthesisTasks(request);}'

In [130]:
features = convert_examples_to_features(examples, tokenizer, args)

  0%|          | 0/1 [00:00<?, ?it/s]

[('ListSpeechSynthesisTasksResult', 1, 'comesFrom', [], []), ('listSpeechSynthesisTasks', 2, 'comesFrom', [], []), ('', 3, 'comesFrom', [], []), ('ListSpeechSynthesisTasksRequest', 5, 'comesFrom', [], []), ('request', 6, 'comesFrom', [], []), ('request', 9, 'computedFrom', ['beforeClientExecution'], [11]), ('request', 9, 'computedFrom', ['request'], [13]), ('beforeClientExecution', 11, 'comesFrom', [], []), ('request', 13, 'comesFrom', ['request'], [6]), ('executeListSpeechSynthesisTasks', 17, 'comesFrom', [], []), ('request', 19, 'comesFrom', ['request'], [9])]
[['public'], ['ĠList', 'Spe', 'ech', 'Sy', 'nt', 'hesis', 'T', 'asks', 'Result'], ['Ġlist', 'Spe', 'ech', 'Sy', 'nt', 'hesis', 'T', 'asks'], ['Ġ'], ['Ġ('], ['ĠList', 'Spe', 'ech', 'Sy', 'nt', 'hesis', 'T', 'asks', 'Request'], ['Ġrequest'], ['Ġ)'], ['Ġ{'], ['Ġrequest'], ['Ġ='], ['Ġbefore', 'Client', 'Exec', 'ution'], ['Ġ('], ['Ġrequest'], ['Ġ)'], ['Ġ;'], ['Ġreturn'], ['Ġexecute', 'List', 'Spe', 'ech', 'Sy', 'nt', 'hesis', 'T', '

In [16]:
features

[<__main__.InputFeatures at 0x7fb56ba910a0>,
 <__main__.InputFeatures at 0x7fb56ba911f0>,
 <__main__.InputFeatures at 0x7fb56ba91070>,
 <__main__.InputFeatures at 0x7fb56ba91040>,
 <__main__.InputFeatures at 0x7fb56ba6ca00>,
 <__main__.InputFeatures at 0x7fb56ba6ceb0>]

In [17]:
features[0]

<__main__.InputFeatures at 0x7fb56ba910a0>

In [28]:
print(features[0].example_id)


0


In [27]:
print(features[0].source_ids)


[0, 15110, 9527, 29235, 7529, 35615, 3999, 35571, 565, 40981, 48136, 889, 29235, 7529, 35615, 3999, 35571, 565, 40981, 1437, 36, 9527, 29235, 7529, 35615, 3999, 35571, 565, 40981, 45589, 2069, 4839, 25522, 2069, 5457, 137, 47952, 46891, 15175, 36, 2069, 4839, 25606, 671, 11189, 36583, 29235, 7529, 35615, 3999, 35571, 565, 40981, 36, 2069, 4839, 25606, 35524, 2, 3, 3, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [29]:
print(features[0].position_idx)

[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [30]:
print(features[0].dfg_to_code)

[(30, 31), (33, 34), (33, 34), (35, 39), (40, 41), (54, 55)]


In [31]:
print(features[0].dfg_to_dfg)

[[], [3], [4], [], [0], [2]]


In [32]:
print(features[0].target_ids)

[0, 15110, 6229, 9527, 29235, 7529, 35615, 3999, 35571, 565, 40981, 47806, 9527, 29235, 7529, 35615, 3999, 35571, 565, 40981, 1640, 36583, 29235, 7529, 35615, 3999, 35571, 565, 40981, 45589, 2069, 48512, 10806, 1735, 5457, 92, 9318, 5361, 47261, 47006, 45012, 4, 45589, 40825, 1250, 254, 5457, 9527, 29235, 7529, 35615, 3999, 35571, 565, 40981, 45589, 40825, 1250, 254, 4, 49483, 131, 45012, 4, 47806, 9685, 119, 14980, 1250, 254, 5457, 9527, 29235, 7529, 35615, 3999, 35571, 565, 40981, 47806, 9685, 119, 14980, 1250, 254, 4, 49483, 131, 30921, 9318, 5361, 41552, 36583, 29235, 7529, 35615, 3999, 35571, 565, 40981, 47806, 49925, 16604, 6, 1735, 4397, 24303, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [33]:
print(features[0].source_mask)

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [34]:
print(features[0].target_mask)

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


# Checking target_mask and source_mask

In [36]:
sum(features[0].target_mask)

108

In [40]:
len(tokenizer(examples[0].target).input_ids)

108

In [41]:
sum(features[0].source_mask)

65

In [42]:
len(tokenizer(examples[0].source).input_ids)

56

In [52]:
len(tokenizer.tokenize(examples[0].source))

54

In [56]:
tokenizer.unk_token_id

3

# Check extract dataflow

In [92]:
sample_input = examples[1].source
sample_input

'public UpdateJourneyStateResult updateJourneyState(UpdateJourneyStateRequest request) {request = beforeClientExecution(request);return executeUpdateJourneyState(request);}'

In [93]:
code_tokens, dfg = extract_dataflow(sample_input, parsers["java"], "java")

In [94]:
code_tokens

['public',
 'UpdateJourneyStateResult',
 'updateJourneyState',
 '',
 '(',
 'UpdateJourneyStateRequest',
 'request',
 ')',
 '{',
 'request',
 '=',
 'beforeClientExecution',
 '(',
 'request',
 ')',
 ';',
 'return',
 'executeUpdateJourneyState',
 '(',
 'request',
 ')',
 ';',
 '}']

In [76]:
dfg

[('request', 6, 'comesFrom', [], []),
 ('request', 9, 'computedFrom', ['beforeClientExecution'], [11]),
 ('request', 9, 'computedFrom', ['request'], [13]),
 ('beforeClientExecution', 11, 'comesFrom', [], []),
 ('request', 13, 'comesFrom', ['request'], [6]),
 ('request', 19, 'comesFrom', ['request'], [9])]

In [63]:
temp = [tokenizer.tokenize('@ '+x)[1:] if idx!=0 else tokenizer.tokenize(x) for idx,x in enumerate(code_tokens)]
temp=[y for x in temp for y in x]  

In [65]:
len(temp)

57

In [66]:
57 + 2 + 6

65

In [71]:
code_tokens
for i in (code_tokens):
    print(len(i))

6
30
24
0
1
31
7
1
1
7
1
21
1
7
1
1
6
31
1
7
1
1
1


In [68]:
ori2cur_pos={}
ori2cur_pos[-1]=(0,0)
for i in range(len(code_tokens)):
    print(ori2cur_pos[i-1][1])
    print(ori2cur_pos[i-1][1]+len(code_tokens[i]))
    print()
    ori2cur_pos[i]=(ori2cur_pos[i-1][1],ori2cur_pos[i-1][1]+len(code_tokens[i]))

0
6

6
36

36
60

60
60

60
61

61
92

92
99

99
100

100
101

101
108

108
109

109
130

130
131

131
138

138
139

139
140

140
146

146
177

177
178

178
185

185
186

186
187

187
188



In [72]:
ori2cur_pos

{-1: (0, 0),
 0: (0, 6),
 1: (6, 36),
 2: (36, 60),
 3: (60, 60),
 4: (60, 61),
 5: (61, 92),
 6: (92, 99),
 7: (99, 100),
 8: (100, 101),
 9: (101, 108),
 10: (108, 109),
 11: (109, 130),
 12: (130, 131),
 13: (131, 138),
 14: (138, 139),
 15: (139, 140),
 16: (140, 146),
 17: (146, 177),
 18: (177, 178),
 19: (178, 185),
 20: (185, 186),
 21: (186, 187),
 22: (187, 188)}

In [73]:
len(code_tokens)

23

In [84]:
reverse_index={}
temp_dfg=dfg.copy()
for idx,x in enumerate(temp_dfg):
    reverse_index[x[1]]=idx
for idx,x in enumerate(temp_dfg):
    temp_dfg[idx]=x[:-1]+([reverse_index[i] for i in x[-1] if i in reverse_index],)

In [78]:
reverse_index

{6: 0, 9: 2, 11: 3, 13: 4, 19: 5}

In [82]:
dfg

[('request', 6, 'comesFrom', [], []),
 ('request', 9, 'computedFrom', ['beforeClientExecution'], [11]),
 ('request', 9, 'computedFrom', ['request'], [13]),
 ('beforeClientExecution', 11, 'comesFrom', [], []),
 ('request', 13, 'comesFrom', ['request'], [6]),
 ('request', 19, 'comesFrom', ['request'], [9])]

In [85]:
temp_dfg

[('request', 6, 'comesFrom', [], []),
 ('request', 9, 'computedFrom', ['beforeClientExecution'], [3]),
 ('request', 9, 'computedFrom', ['request'], [4]),
 ('beforeClientExecution', 11, 'comesFrom', [], []),
 ('request', 13, 'comesFrom', ['request'], [0]),
 ('request', 19, 'comesFrom', ['request'], [2])]

In [86]:
dfg_to_dfg=[x[-1] for x in temp_dfg]
dfg_to_code=[ori2cur_pos[x[1]] for x in temp_dfg]
length=len([tokenizer.cls_token])
# dfg_to_code=[(x[0]+length,x[1]+length) for x in dfg_to_code]

In [87]:
dfg_to_dfg

[[], [3], [4], [], [0], [2]]

In [88]:
dfg_to_code

[(92, 99), (101, 108), (101, 108), (109, 130), (131, 138), (178, 185)]

In [89]:
[(x[0]+length,x[1]+length) for x in dfg_to_code]

[(93, 100), (102, 109), (102, 109), (110, 131), (132, 139), (179, 186)]