In [2]:
import requests
import os
import time
import subprocess
import shutil
import csv
import re
import random
import stat
import errno
import pandas as pd
import javalang

GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")

# API Endpoints
GITHUB_API_URL = "https://api.github.com"
SEARCH_REPOS_ENDPOINT = f"{GITHUB_API_URL}/search/repositories"

# Target data sizes
TARGET_TRAIN_SIZE = 25000
TARGET_TEST_SIZE = 5000
TOTAL_TARGET_SIZE = TARGET_TRAIN_SIZE + TARGET_TEST_SIZE

def search_java_repos(query, per_page=100, pages=1):
    """
    Searches for popular Java repositories on GitHub based on stars.
    """
    repos = []
    headers = {"Authorization": f"token {GITHUB_TOKEN}"} if GITHUB_TOKEN else {} 
    for page in range(1, pages + 1):
        params = {
            "q": query,
            "sort": "stars",
            "order": "desc",
            "per_page": per_page,
            "page": page
        }        
        try:
            response = requests.get(SEARCH_REPOS_ENDPOINT, headers=headers, params=params)
            response.raise_for_status() # Raise an exception for bad status codes            
            data = response.json()
            items = data.get("items", [])
            repos.extend(items)            
            print(f"Found {len(items)} repositories on page {page}.")
            time.sleep(1)           
        except Exception as e:
            print(f"Error searching for repositories: {e}")
            break
    return repos


def clone_repo(repo_url, clone_dir):
    """
    Clones a repository to a local directory.
    """
    print(f"Cloning repository from '{repo_url}'...")
    try:
        subprocess.run(["git", "clone", "--depth", "1", repo_url, clone_dir], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        print(f"Successfully cloned to '{clone_dir}'.")
    except Exception as e:
        print(f"Error cloning repository: {e}")

        
def get_latest_commit_sha(repo_dir):
    """
    Gets the latest commit SHA for a cloned repository.
    """
    try:
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            cwd=repo_dir,
            check=True,
            capture_output=True,
            text=True
        )
        return result.stdout.strip()
    except Exception as e:
        print(f"Error getting commit SHA for {repo_dir}: {e}")
        return None

        
def get_java_files_local(repo_dir):
    """
    Finds all .java files in a local directory.
    """
    java_files = []
    for root, _, files in os.walk(repo_dir):
        for file in files:
            if file.endswith(".java"):
                java_files.append(os.path.join(root, file))
    return java_files

def cleanup_repo(repo_dir):
    """
    Deletes the local repository directory.
    """
    print(f"Cleaning up: Deleting '{repo_dir}'.")
    try:
        shutil.rmtree(repo_dir, onerror=on_rm_error)
        print("Cleanup successful.")
    except Exception as e:
        print(f"Error deleting directory: {e}")

        
def parse_java_methods(java_code):
    """
    Parses Java code by tracking curly brace nesting to extract methods.
    """
    methods = []
    lines = java_code.splitlines()
    
    # This regex is now only used to find the start of a method signature.
    method_signature_regex = re.compile(
        r'^\s*(?:public|protected|private|static|final|\s)+[\w<>\.\s,\[\]]+\s+(\w+)\s*\(.*?\)\s*(?:throws\s+[\w\s,]+)?\s*\{',
        re.MULTILINE | re.DOTALL
    )
    current_line_number = 0
    while current_line_number < len(lines):
        line = lines[current_line_number]
        match = method_signature_regex.search(line)
        if match:
            brace_count = 1
            start_line_number = current_line_number + 1
            signature = line.strip()
            while not signature.endswith('{') and current_line_number + 1 < len(lines):
                current_line_number += 1
                signature += " " + lines[current_line_number].strip()
            signature = signature.strip().rstrip('{').strip()
            method_name = match.group(1)
            original_code_lines = [lines[start_line_number-1]]
            i = start_line_number
            while i < len(lines):
                current_line = lines[i]
                original_code_lines.append(current_line)
                brace_count += current_line.count('{')
                brace_count -= current_line.count('}')
                if brace_count == 0:
                    end_line_number = i + 1
                    original_code = "\n".join(original_code_lines).strip()
                    code_tokens = original_code.split()
                    methods.append({
                        "method_name": method_name,
                        "start_line": start_line_number,
                        "end_line": end_line_number,
                        "signature": signature,
                        "original_code": original_code,
                        "code_tokens": " ".join(code_tokens)
                    })
                    current_line_number = i
                    break
                i += 1
        current_line_number += 1  
    return methods


def write_data_to_csv(data, filename):
    """
    Writes a list of data dictionaries to a single CSV file.
    """
    print(f"Writing {len(data)} entries to '{filename}'.")
    fieldnames = [
        "dataset_split", "repo_name", "repo_url", "commit_sha", 
        "file_path", "method_name", "start_line", "end_line", 
        "signature", "original_code", "code_tokens"
    ]
    with open(filename, 'w', newline='', encoding='utf-8') as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(data)
    print(f"Successfully wrote data to '{filename}'.")


def tokenize_java(code):
    """
    Tokenizes a Java code string using javalang.
    Returns a list of tokens as strings.
    """
    try:
        tokens = list(javalang.tokenizer.tokenize(code))
        return [t.value for t in tokens]
    except:
        # return empty list if tokenization fails
        return []

In [4]:
def main():
    # if not GITHUB_TOKEN:
    #     print("Warning: GITHUB_TOKEN environment variable not set. Rate limits will be very strict.")
    #     print("Please set your token to run this script effectively.")    
    all_methods = []
    seen_methods = set()
    query = "language:Java"
    repos = search_java_repos(query, pages=1)
    if not repos:
        print("No repositories found. Exiting.")
        return
    
    java_keywords = {"class", "public", "private", "protected", "static", "void", "String", "int", "double", "boolean", "if", 
                     "else", "for", "while", "return"}
    for repo in repos:
        if len(all_methods) >= TOTAL_TARGET_SIZE:
            print(f"Method count reached {TOTAL_TARGET_SIZE}. Stopping further repository processing.")
            break     
        repo_name = repo["full_name"]
        repo_url = repo["html_url"]
        repo_dir = f"./cloned_repos/{repo_name.replace('/', '_')}"
        print(f"\n--- Processing repository: {repo_name} ---")
        
        # Clone the queried repos 
        clone_repo(repo_url, repo_dir)
        if os.path.exists(repo_dir):
            # Get the latest commit SHA
            latest_sha = get_latest_commit_sha(repo_dir)
            # Find the cloned repos in local directory
            java_files = get_java_files_local(repo_dir)
            if not java_files:
                print("No Java files found. Skipping to next repo.")
                cleanup_repo(repo_dir)
                continue  
            for file_path in java_files:
                try:
                    with open(file_path, "r", encoding="utf-8") as f:
                        java_code = f.read()
                        methods = parse_java_methods(java_code)
                        for method in methods:
                            # Create a unique identifier to check for duplicates
                            unique_id = f"{repo_name}_{file_path.replace(repo_dir + os.sep, '')}_{method['signature']}"
                            if unique_id in seen_methods:
                                continue
                            code_tokens = set(method["original_code"].split())
                            if not code_tokens.intersection(java_keywords):
                                continue
                            all_methods.append({
                                "repo_name": repo_name,
                                "repo_url": repo_url,
                                "commit_sha": latest_sha,
                                "file_path": file_path.replace(repo_dir + os.sep, ""),
                                **method
                            })
                            seen_methods.add(unique_id)            
                except Exception as e:
                    print(f"Error reading file '{file_path}': {e}")
            
            # Delete the cloned repository after processing
            cleanup_repo(repo_dir)
            
    # Check if data count exceeds the required count
    if len(all_methods) < TOTAL_TARGET_SIZE:
        print(f"\nCould only collect {len(all_methods)} methods, which is less than the target of {TOTAL_TARGET_SIZE}.")
        print("Try increasing the query size.")        
    else:
        # Shuffle the methods to ensure a random distribution
        random.shuffle(all_methods)
        # Add the 'dataset_split' field and write to a single CSV
        for i, method in enumerate(all_methods[:TOTAL_TARGET_SIZE]):
            if i < TARGET_TRAIN_SIZE:
                method["dataset_split"] = "train"
            else:
                method["dataset_split"] = "test"      
        write_data_to_csv(all_methods[:TOTAL_TARGET_SIZE], "methods1.csv")  

    #tokenize the collected methods
    df = pd.read_csv("methods1.csv")
    df['tokenized_code'] = df['original_code'].apply(tokenize_java)
    df.to_csv("methods_tokenized1.csv", index=False)
    print("\n--- Process finished ---")

In [5]:
if __name__ == "__main__":
    # Create the directory to store cloned repositories
    if not os.path.exists("./cloned_repos"):
        os.makedirs("./cloned_repos")   
    main()

Found 100 repositories on page 1.

--- Processing repository: Snailclimb/JavaGuide ---
Cloning repository from 'https://github.com/Snailclimb/JavaGuide'...
Successfully cloned to './cloned_repos/Snailclimb_JavaGuide'.
No Java files found. Skipping to next repo.
Cleaning up: Deleting './cloned_repos/Snailclimb_JavaGuide'.
Error deleting directory: name 'on_rm_error' is not defined

--- Processing repository: krahets/hello-algo ---
Cloning repository from 'https://github.com/krahets/hello-algo'...
Successfully cloned to './cloned_repos/krahets_hello-algo'.
Cleaning up: Deleting './cloned_repos/krahets_hello-algo'.
Error deleting directory: name 'on_rm_error' is not defined

--- Processing repository: GrowingGit/GitHub-Chinese-Top-Charts ---
Cloning repository from 'https://github.com/GrowingGit/GitHub-Chinese-Top-Charts'...
Successfully cloned to './cloned_repos/GrowingGit_GitHub-Chinese-Top-Charts'.
Cleaning up: Deleting './cloned_repos/GrowingGit_GitHub-Chinese-Top-Charts'.
Error delet