In [4]:
import re
import os
import sys
import json
import random
import subprocess
from tqdm import tqdm

def retrieve_file(repo_path: str, file_path: str, commit_sha: str):
    '''
    Func: Retrieve file of a given version
    Args:
        repo_path: str, the repository directory
        file_path: str, the relative path for the file in 
                        terms of the repository
        commit_sha: the commit version
    '''
    # Save the current working directory
    original_dir = os.getcwd()

    try:
        # Change to the repository directory
        os.chdir(repo_path)

        # Check out the specific commit
        checkout_command = f'git checkout -q {commit_sha}'
        subprocess.run(checkout_command, shell=True, check=True)

        # Copy the file to the desired location
        source_file = os.path.normpath(file_path)
        with open(source_file, 'r', encoding='utf-8') as f:
            content = f.readlines()

        # Always return to the original working directory
        os.chdir(original_dir)
        return content
    except:
        # Always return to the original working directory
        os.chdir(original_dir)
        raise KeyError('Unable to find the file.')
    
def extract_info(old_file_path, new_file_path):
    file_path = '/'.join(old_file_path.split('/')[3:])
    elements = old_file_path.split('/')[2].split('_')
    user_name = elements[0]
    old_sha = elements[-1]
    proj_name = '_'.join(elements[1:-1])
    new_sha = new_file_path.split('/')[2].split('_')[-1]

    return user_name, proj_name, old_sha, new_sha, file_path

def contains_non_english_content(content):
    if type(content) == list:
        content = ''.join(content)
    pattern = re.compile(r'[^\x00-\x7F]+')
    match = pattern.search(content)
    
    if match:
        return True
    else:
        return False
    
def apply_filter(data_sample):
    # 1. if code window is empty, return False
    if len(data_sample['code_window']) == 0:
        return False
    # 2. if hunk type is 'add', but add_line is empty, return False
    if data_sample['hunk_type'] == 'add' and len(data_sample['add_line']) == 0:
        return False
    # 3. if code_window, add_line, commit message contains more than english words, return False
    if contains_non_english_content(data_sample['commit_msg']) or \
        contains_non_english_content(data_sample['add_line']) or \
        contains_non_english_content(data_sample['code_window']):
        return False

    return True

In [5]:
# 1. Load raw dataset int jsonl format
dataset = []
with open('./dataset/java_dataset.jsonl', 'r', encoding='utf-8') as file:
    for i in file.readlines():
        data = json.loads(i)
        dataset.append(data)
print('Dataset size:',len(dataset))
print('Dataset example:\n',dataset[0])

Dataset size: 7282
Dataset example:
 {'user_name': 'Snailclimb', 'proj_name': 'JavaGuide', 'old_sha': 'b1576701bbbbfefa247e5d361e5b10f02e360e2a', 'new_sha': '789406cc28c9974334f2f07b0e85f0eb559d93d5', 'file_path': '分布式.md', 'changes': [{'func_name': '', 'del_line_idx': [], 'add_line_idx': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], 'del_line': '', 'add_line': '  - ### 分布式系统的经典基础理论\n  \n    [分布式系统的经典基础理论](https://blog.csdn.net/qq_34337272)\n  - ### 分布式事务\n    [聊聊分布式事务，再说说解决方案](http://www.cnblogs.com/savorboard/p/distributed-system-transaction-consistency.html)\n  - ### 分布式系统一致性\n    [分布式服务化系统一致性的“最佳实干”](https://www.jianshu.com/p/1156151e20c8)\n\n   - ### 一致性协议/算法\n     早在1898年就诞生了著名的 **Paxos经典算法** （**Zookeeper就采用了Paxos算法的近亲兄弟Zab算法**），但由于Paxos算法非常难以理解、实现、排错。所以不断有人尝试简化这一算法，直到2013年才有了重大突破：斯坦福的Diego Ongaro、John Ousterhout以易懂性为目标设计了新的一致性算法—— **Raft算法** ，并发布了对应的论文《In Search of an Understandable Consensus Algorithm》，到现在有十多种语言实现的Raft算法框架，较为出名的有以Go语言实现的Etcd，它的功能类似于Zookeeper，但采用了更为主流的Rest接口。\n  

In [6]:
# 2. categorize the dataset by commit_id
# the unit of datasample of the raw dataset is a changed file, with multiple changed hunks
# commit_id_dict: {commit_url: [the index of data samples in the raw dataset that belongs to this commit]}
commit_id_dict = {}
for idx, i in enumerate(dataset):
    if i['html_url'] in commit_id_dict:
        commit_id_dict[i['html_url']].append(idx)
    else:
        commit_id_dict[i['html_url']] = [idx]

print('The number of commits:' ,len(commit_id_dict))    

The number of commits: 3709


# General

In [None]:
# 3. convert each edit hunk into a data sample
add = []
replace = []
remove = []
prev_line = [3,4,5]
repos_dir = './repos'
processed_data_dir = ''

for idx, i in enumerate(tqdm(dataset)):
    cnt = 0
    if 'old_file_path' in i.keys():
        user_name, proj_name, old_sha, new_sha, file_path = extract_info(i['old_file_path'], i['new_file_path'])
        i['user_name'] = user_name
        i['proj_name'] = proj_name
        i['old_sha'] = old_sha
        i['new_sha'] = new_sha
        i['file_path'] = file_path
    for j in i['changes']:
        # add
        if j['del_line_idx'] == [] and j['add_line_idx']:
            # with open('../'+i['new_file_path'], encoding='utf-8') as file:
            proj_name = i['proj_name']
            commit_sha = i['new_sha']
            file_path = i['file_path']
            repo_path = os.path.join(repos_dir, proj_name)
            # Get the content of this file of the given commit
            try:
                lines = retrieve_file(repo_path, file_path, commit_sha)
            except KeyError:
                break
            except:
                raise KeyError('Unexpected error:', sys.exc_info()[0])

            label = ['keep']*len(lines)
            
            add_line = j['add_line_idx'][0]-1
            start_line = j['add_line_idx'][0]-random.choice(prev_line)
            mid_line = j['add_line_idx'][-1]
            end_line = mid_line+random.choice(prev_line)+1
            
            if start_line < 0:
                start_line = 0
            if end_line > len(lines):
                end_line = -1
                
            code_window = lines[start_line:add_line] + lines[mid_line:end_line]
            label_window = label[start_line:add_line] + ['add'] + label[mid_line:end_line]
            
            data_sample = {
                        'code_window': code_window, 
                        'label_window': label_window, 
                        'commit_msg': i['commit_msg'], 
                        'html_url': i['html_url'], 
                        'add_line': j['add_line'], 
                        'method_name': j['func_name'],
                        'old_file_path': file_path, 
                        'idx': idx,
                        'hunk_type': 'add'
                    }
            if apply_filter(data_sample):
                add.append(data_sample)
                
        # replace
        elif j['del_line_idx'] and j['add_line_idx']:
            # with open('../'+i['old_file_path'], encoding='utf-8') as file:
            proj_name = i['proj_name']
            commit_sha = i['old_sha']
            file_path = i['file_path']
            repo_path = os.path.join(repos_dir, proj_name)
            try:
                lines = retrieve_file(repo_path, file_path, commit_sha)
            except KeyError:
                break
            except:
                raise KeyError('Unexpected error:', sys.exc_info()[0])


            label = ['keep']*len(lines)
            for del_line in j['del_line_idx']:
                label[del_line-1] = 'replace'
            
            start_line = j['del_line_idx'][0]-random.choice(prev_line)
            end_line = j['del_line_idx'][-1]+random.choice(prev_line)
            if start_line < 0:
                start_line = 0
            if end_line > len(lines):
                end_line = -1
            code_window = lines[start_line:end_line]
            label_window = label[start_line:end_line]
            
            data_sample = {
                        'code_window': code_window, 
                        'label_window': label_window, 
                        'commit_msg': i['commit_msg'], 
                        'html_url': i['html_url'], 
                        'add_line': j['add_line'], 
                        'method_name': j['func_name'],
                        'old_file_path': file_path, 
                        'idx': idx,
                        'hunk_type': 'replace'
                    }
            if apply_filter(data_sample):
                replace.append(data_sample)
        
        # remove
        elif j['del_line_idx'] and j['add_line_idx'] == []:
            # with open('../'+i['old_file_path'], encoding='utf-8') as file:
            proj_name = i['proj_name']
            commit_sha = i['old_sha']
            file_path = i['file_path']
            repo_path = os.path.join(repos_dir, proj_name)
            try:
                lines = retrieve_file(repo_path, file_path, commit_sha)
            except KeyError:
                break
            except:
                raise KeyError('Unexpected error:', sys.exc_info()[0])
            
            label = ['keep']*len(lines)
            for del_line in j['del_line_idx']:
                label[del_line-1] = 'remove'
            
            start_line = j['del_line_idx'][0]-random.choice(prev_line)
            end_line = j['del_line_idx'][-1]+random.choice(prev_line)
            if start_line < 0:
                start_line = 0
            if end_line > len(lines):
                end_line = -1
            code_window = lines[start_line:end_line]
            label_window = label[start_line:end_line]
            
            data_sample = {
                        'code_window': code_window, 
                        'label_window': label_window, 
                        'commit_msg': i['commit_msg'], 
                        'html_url': i['html_url'], 
                        'add_line': j['add_line'], 
                        'method_name': j['func_name'],
                        'old_file_path': file_path, 
                        'idx': idx,
                        'hunk_type': 'remove'
                    }
            if apply_filter(data_sample):
                remove.append(data_sample)
            
print('The number of add type hunks:', len(add))
print('The number of replace type hunks:', len(replace))
print('The number of remove type hunks:', len(remove))

In [8]:
# 4. categorize the data samples by commit_id
# {commit url: [data samples in this commit]}
result_dict = {} 
for i in add+replace+remove:
    html_url = i['html_url']
    if html_url in result_dict:
        result_dict[html_url].append(i)
    else:
        result_dict[html_url] = [i]

# save the dict into json file
with open('./dataset/processed_dataset.json', 'w', encoding='utf-8') as file:
    json.dump(result_dict, file, indent=4, ensure_ascii=False)

# create dataset: edit generation

In [131]:
# code window + label_window + commit message + prev_change

In [144]:
# rank by similarity
from rank_bm25 import BM25Okapi
output = []
for commit in sorted(result_dict.keys()):
    for co_change in result_dict[commit]:
        code_window = ''.join(co_change['code_window'])
        label_window = ' '.join(co_change['label_window'])
        commit_message = co_change['commit_msg']
        context = []
        if len(code_window) == 0:
            continue
        
        # BM25 search for related context
        prev_edit = result_dict[commit].copy()
        prev_edit.remove(co_change)
        try:
            tokenized_corpus = [''.join(i['code_window']+[i['add_line']]).split() for i in prev_edit]
            bm25 = BM25Okapi(tokenized_corpus) # build a BM25 object with other hunks
            tokenized_query = code_window.split()
            retrieval_code = bm25.get_top_n(tokenized_query, tokenized_corpus, n=5)
            context_index = [tokenized_corpus.index(i) for i in retrieval_code] # get the index of the top 5 similar hunks

            # form context, which are the deleted and added lines in the top 5 similar hunkss
            for idx in context_index:
                if prev_edit[idx]['hunk_type'] == 'replace': 
                    replace = prev_edit[idx]['label_window'].index('replace')
                    context.append('remove '+ prev_edit[idx]['code_window'][replace])
                    context.append('add '+ prev_edit[idx]['add_line'])

                elif prev_edit[idx]['hunk_type'] == 'remove':
                    remove = prev_edit[idx]['label_window'].index('remove')
                    context.append('remove '+ prev_edit[idx]['code_window'][remove])

                elif prev_edit[idx]['label_window'] == 'add':
                    context.append('add '+ prev_edit[idx]['add_line'])
        except:
            pass
        
        input_ = ' </s> '.join([code_window, label_window, commit_message] + context)
        output_ =   co_change['add_line']
        html_url =  co_change['html_url']
        file_name = co_change['file_path']
        output.append({"docstring_tokens":output_, "code_tokens":input_, "html_url":html_url, "file_name":file_name})
 

In [150]:
len(output)

209585

In [153]:
import jsonlines
# final data format: {"docstring_tokens":doc_tokens, "code_tokens":code_tokens}
os.path.join(processed_data_dir, 'generator/train.jsonl')
with jsonlines.open(os.path.join(processed_data_dir, 'generator/train.jsonl'), 'w') as f:
    for item in output[:int(0.7*len(output))]:
        f.write(item)
with jsonlines.open(os.path.join(processed_data_dir, 'generator/dev.jsonl'), 'w') as f:
    for item in output[int(0.7*len(output)): int(0.8*len(output))]:
        f.write(item)
with jsonlines.open(os.path.join(processed_data_dir, 'generator/test.jsonl'), 'w') as f:
    for item in output[int(0.8*len(output)):]:
        f.write(item)

# create dataset: edit locator

In [None]:
# rank by similarity
from rank_bm25 import BM25Okapi
output = []
for commit in result_dict:
    for co_change in result_dict[commit]:
        code_window = ''.join(co_change['code_window'])
        label_window = ' '.join(co_change['label_window'])
        commit_message = co_change['commit_msg']
        context = []
        if len(code_window) == 0:
            continue
        
        # BM25 search for related context
        prev_edit = result_dict[commit].copy()
        prev_edit.remove(co_change)
        try:
            tokenized_corpus = [''.join(i['code_window']+[i['add_line']]).split() for i in prev_edit]
            bm25 = BM25Okapi(tokenized_corpus)
            tokenized_query = code_window.split()
            retrieval_code = bm25.get_top_n(tokenized_query, tokenized_corpus, n=5)
            context_index = [tokenized_corpus.index(i) for i in retrieval_code]

            for idx in context_index:
                if prev_edit[idx]['hunk_type'] == 'replace':
                    replace = prev_edit[idx]['label_window'].index('replace')
                    context.append('remove '+ prev_edit[idx]['code_window'][replace])
                    context.append('add '+ prev_edit[idx]['add_line'])

                elif prev_edit[idx]['hunk_type'] == 'remove':
                    remove = prev_edit[idx]['label_window'].index('remove')
                    context.append('remove '+ prev_edit[idx]['code_window'][remove])

                elif prev_edit[idx]['hunk_type'] == 'add':
                    context.append('add '+ prev_edit[idx]['add_line'])
        except:
            pass
        
        input_ = ' </s> '.join([code_window, commit_message] + context)
        output_ =   co_change['add_line']
        html_url =  co_change['html_url']
        file_name = co_change['file_path']
        output.append({"docstring_tokens":label_window, "code_tokens":input_, "html_url":html_url, "file_name":file_name})
 

In [None]:
import jsonlines
# create dataset: code generation
# final data format: {"docstring_tokens":doc_tokens, "code_tokens":code_tokens}

with jsonlines.open(os.path.join(processed_data_dir, 'locator/train.jsonl'), 'w') as f:
    for item in output[:int(0.7*len(output))]:
        f.write(item)
with jsonlines.open(os.path.join(processed_data_dir, 'locator/dev.jsonl'), 'w') as f:
    for item in output[int(0.7*len(output)): int(0.8*len(output))]:
        f.write(item)
with jsonlines.open(os.path.join(processed_data_dir, 'locator/test.jsonl'), 'w') as f:
    for item in output[int(0.8*len(output)):]:
        f.write(item)