In [83]:
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm

tqdm.pandas()

### Load CodeSearchNet Dataset and Append Relevance Scores

In [84]:
# Read the relevance annotations and keep only the rows annotated for Java.
relevance_df = (
    pd.read_csv('data/annotationStore.csv')
    .loc[lambda frame: frame['Language'] == 'Java']
)
relevance_df.head()

Unnamed: 0,Language,Query,GitHubUrl,Relevance,Notes
166,Java,fuzzy match ranking,https://github.com/spotbugs/spotbugs/blob/f636...,0,
167,Java,create cookie,https://github.com/apache/spark/blob/25ee0474f...,2,
168,Java,parse query string in url,https://github.com/tanhaichao/leopard-lang/blo...,0,
169,Java,convert int to string,https://github.com/hankcs/HanLP/blob/a538d0722...,0,
170,Java,deducting the median from each column,https://github.com/datacleaner/AnalyzerBeans/b...,0,


In [85]:
# Only the first 1% of the Java training split is loaded here; to work on the
# whole corpus, use split='train+test+validation' instead.
ds_train = load_dataset(
    path="code_search_net",
    name="java",
    split='train[:1%]',
)
ds_train

Dataset({
    features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'],
    num_rows: 4545
})

In [86]:
def get_relevance(repo_url):
    """Return the (query, relevance) annotation stored for a GitHub URL.

    Searches the module-level ``relevance_df`` for rows whose ``GitHubUrl``
    equals ``repo_url``. If several rows match, the first one is used.
    Returns ``(None, None)`` when no row matches.
    """
    matches = relevance_df[relevance_df['GitHubUrl'] == repo_url]

    # Guard clause: nothing annotated for this URL.
    if matches.empty:
        return None, None

    query, relevance = matches.iloc[0][['Query', 'Relevance']]
    return query, relevance

# Used to check whether a docstring is written in a language other than English.
def is_ascii(s):
    """Return True when every character of ``s`` has a code point below 128."""
    return not any(ord(ch) >= 128 for ch in s)

In [87]:
# Spot check: look up the annotation for one specific function URL.
get_relevance('https://github.com/spring-projects/spring-boot/blob/0b27f7c70e164b2b1a96477f1d9c1acba56790c1/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/info/GitProperties.java#L106-L118')

('convert a utc time to epoch', 2)

In [88]:
# Collect the annotated (query, relevance) pair for every function in ds_train.
# Rows whose docstring contains non-ASCII characters, and rows without an
# annotation, get None for both values so the lists stay aligned with ds_train.
relevance_scores = []
queries = []

# Iterating the dataset directly (with an explicit total) lets tqdm show a real
# progress bar instead of a bare iteration counter.
for row in tqdm(ds_train, total=len(ds_train)):
    if not is_ascii(row['func_documentation_string']):
        # Keep the row's position but leave it unlabelled.
        relevance_scores.append(None)
        queries.append(None)
        continue

    query, score = get_relevance(row['func_code_url'])

    relevance_scores.append(score)
    queries.append(query)

assert len(relevance_scores) == len(ds_train)
# Compare each entry against None explicitly: `any(...) is not None` is always
# true, and a plain truthiness test would treat the valid label 0 as missing.
assert any(score is not None for score in relevance_scores), \
    "no function URL in ds_train matched an annotation"

4545it [00:00, 4668.69it/s]


In [89]:
# Print every relevance score that is not a None placeholder, one per line.
found_scores = [value for value in relevance_scores if value is not None]
for value in found_scores:
    print(value)

3


In [90]:
# Show how many entries (including None placeholders) were collected.
len(relevance_scores)

4545

In [91]:
# Attach the relevance labels and their queries as two new columns.
ds_train = (
    ds_train
    .add_column("label", relevance_scores)
    .add_column("query", queries)
)
ds_train

Dataset({
    features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url', 'label', 'query'],
    num_rows: 4545
})

In [92]:
# Drop the metadata, token and docstring columns listed below.
unused_columns = [
    'repository_name',
    'func_path_in_repository',
    'func_name',
    'whole_func_string',
    'language',
    'func_code_url',
    'split_name',
    'func_code_tokens',
    'func_documentation_string',
    'func_documentation_tokens',
]
ds_train = ds_train.remove_columns(unused_columns)
ds_train

Dataset({
    features: ['func_code_string', 'label', 'query'],
    num_rows: 4545
})

In [93]:
def _has_label(example):
    """Return True for examples whose ``label`` is not None."""
    return example['label'] is not None


# Keep only the examples that received a relevance label.
ds_train_filtered = ds_train.filter(_has_label)
ds_train_filtered

Filter:   0%|          | 0/4545 [00:00<?, ? examples/s]

Dataset({
    features: ['func_code_string', 'label', 'query'],
    num_rows: 1
})

In [94]:
# Taken from CodeBERT Preprocessing steps

def format_str(string):
    """Flatten ``string`` onto one line by turning every line break into a space.

    CRLF pairs are replaced before lone CR and LF characters, so a Windows
    line ending becomes one space rather than two.
    """
    return string.replace('\r\n', ' ').replace('\r', ' ').replace('\n', ' ')

In [95]:
def concat_nl_and_code(data):
    """Add a ``text`` field that joins the query and the code on one line.

    The value is ``query + '<CODESPLIT>' + func_code_string`` passed through
    ``format_str``. ``data`` is modified in place and also returned.
    """
    combined = data['query'] + '<CODESPLIT>' + data['func_code_string']
    data['text'] = format_str(combined)
    return data

# Apply concat_nl_and_code to every example, which adds the 'text' column.
ds_train_filtered = ds_train_filtered.map(concat_nl_and_code)
ds_train_filtered

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Dataset({
    features: ['func_code_string', 'label', 'query', 'text'],
    num_rows: 1
})

In [96]:
# Display the first example of the filtered dataset as a spot check.
ds_train_filtered[0]

{'func_code_string': 'protected static Number stringToNumber(final String val) throws NumberFormatException {\n        char initial = val.charAt(0);\n        if ((initial >= \'0\' && initial <= \'9\') || initial == \'-\') {\n            // decimal representation\n            if (isDecimalNotation(val)) {\n                // quick dirty way to see if we need a BigDecimal instead of a Double\n                // this only handles some cases of overflow or underflow\n                if (val.length()>14) {\n                    return new BigDecimal(val);\n                }\n                final Double d = Double.valueOf(val);\n                if (d.isInfinite() || d.isNaN()) {\n                    // if we can\'t parse it as a double, go up to BigDecimal\n                    // this is probably due to underflow like 4.32e-678\n                    // or overflow like 4.65e5324. The size of the string is small\n                    // but can\'t be held in a Double.\n                    retur

In [97]:
# Write the filtered dataset to disk under data/.
ds_train_filtered.save_to_disk("data/code_search_net_relevance.hf")

Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]