In [36]:
import pandas as pd
from huggingface_hub import login
from datasets import load_dataset
import time
import requests
from tqdm import tqdm

# JavaDoc-Code Similarity
### Login to Huggingface

In [3]:
with open('secrets/hugging_face_key.txt') as f:
    login(f.read())

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /Users/marcus/.cache/huggingface/token
Login successful


### Load dataset

In [8]:
ds = load_dataset("code_search_net", "java", split='train', streaming=True)

In [11]:
row = next(iter(ds))
print(row.keys())
print(row['func_documentation_string'])
print(row['func_code_string'])
print(row['repository_name'])

dict_keys(['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'])
Bind indexed elements to the supplied collection.
@param name the name of the property to bind
@param target the target bindable
@param elementBinder the binder to use for elements
@param aggregateType the aggregate type, may be a collection or an array
@param elementType the element type
@param result the destination for results
protected final void bindIndexed(ConfigurationPropertyName name, Bindable<?> target,
			AggregateElementBinder elementBinder, ResolvableType aggregateType,
			ResolvableType elementType, IndexedCollectionSupplier result) {
		for (ConfigurationPropertySource source : getContext().getSources()) {
			bindIndexed(source, name, target, elementBinder, result, aggregateType,
					elementType);
			if (result.wasSupplied() && result.get() !

#### Get stargazers for repo

In [63]:
repo_stars = {}

headers = {"Accept": "application/vnd.github+json"}

def get_github_data(repo):
    if repo in repo_stars.keys():
        return repo_stars[repo]

    response = requests.get('https://api.github.com/repos/' + repo, headers=headers)

    if response.status_code != 200:
        print(response.status_code)
        return None

    print("Getting data from GitHub for: " + repo)
    stars = response.json()['stargazers_count']
    repo_stars[repo] = stars

    return stars

In [68]:
get_github_data(row['repository_name'])

68742

In [65]:
repo_stars

{'spring-projects/spring-boot': 68742}

### Preprocess data

In [69]:
def preprocess(row):
    stars = get_github_data(row['repository_name'])

    return pd.DataFrame({"docstring": row['func_documentation_string'], "code": row['func_code_string'], 'stars': stars, 'repo': row['repository_name']}, index=[0])

In [70]:
test = preprocess(row)
test

Unnamed: 0,docstring,code,stars,repo
0,Stop the application managed by this instance....,public void stop()\n\t\t\tthrows MojoExecution...,68742,spring-projects/spring-boot


### Process Data

In [71]:
# Used to check of docstring is written in a different language other than English.
def is_ascii(s):
    return all(ord(c) < 128 for c in s)

In [72]:
NUMBER_OF_CLASSES = 100
WAIT_TIME = 3

df = pd.DataFrame()

for i, row in tqdm(enumerate(iter(ds))):
    if not is_ascii(row['func_documentation_string']):
        continue

    proc_df = preprocess(row)
    df = pd.concat([df, proc_df], ignore_index=True)

    time.sleep(WAIT_TIME)

    if i == NUMBER_OF_CLASSES:
        break

df.head()

100it [05:00,  3.01s/it]


Unnamed: 0,docstring,code,stars,repo
0,Bind indexed elements to the supplied collecti...,protected final void bindIndexed(Configuration...,68742,spring-projects/spring-boot
1,Set {@link ServletRegistrationBean}s that the ...,public void setServletRegistrationBeans(\n\t\t...,68742,spring-projects/spring-boot
2,Add {@link ServletRegistrationBean}s for the f...,public void addServletRegistrationBeans(\n\t\t...,68742,spring-projects/spring-boot
3,Set servlet names that the filter will be regi...,public void setServletNames(Collection<String>...,68742,spring-projects/spring-boot
4,Add servlet names for the filter.\n@param serv...,public void addServletNames(String... servletN...,68742,spring-projects/spring-boot


In [73]:
df.to_csv('data/processed.csv')