In [None]:
pip install tiktoken

In [None]:
import openai
import requests
import re
import os
import shutil
import time
from tqdm import tqdm
import tiktoken
from dotenv import load_dotenv, find_dotenv
from typing import List,Tuple,Dict
from datetime import timedelta
# from transformers import AutoTokenizer


# Read a .env file located by find_dotenv() and take the OpenAI key from the
# OPENAI_API_KEY environment variable (os.getenv yields None when it is unset).
_ = load_dotenv(find_dotenv())
openai.api_key  = os.getenv('OPENAI_API_KEY')
# Define regex patterns for comments
# These are combined with "|" and applied by remove_comments() without any
# regex flags, so "." stops at a newline and each "//" pattern covers the rest
# of one line only.
java_single_line_comment_regex = r"\/\/.*"
# "/*", then any run of non-"*" characters or "*" not followed by "/", then "*/".
java_multiline_comment_regex = r"\/\*(?:[^*]|\*(?!\/))*\*\/"
kotlin_single_line_comment_regex = r"\/\/.*"
# Non-greedy: ends at the first "*/" after the opening "/*".
kotlin_multiline_comment_regex = r"\/\*[\s\S\n]*?\*\/"

def clear_response(response:str)->str:
    """Trim a model reply down to the source code it contains.

    The returned slice starts at the first ``package <name>;`` declaration
    (spelled ``package`` or ``Package``) when there is one, otherwise at the
    start of ``response``.  It ends at the last ``}``.  When ``response`` has
    no ``}`` at all the result is the empty string.
    """
    # ``[pP]`` instead of ``[p|P]``: inside a character class ``|`` is a
    # literal, so the old pattern also accepted the text "|ackage".
    package_decl = re.search(r'\b[pP]ackage\s+([\w.]+)\s*;', response)
    start = package_decl.start() if package_decl else 0
    return response[start:response.rfind('}') + 1]
def find_last_brace(source_code: str) -> int:
    """Return the index of the last ``}`` in ``source_code``, or -1 if there is none."""
    return source_code.rfind('}')

def remove_comments(content,file_type):
    """
    Return ``content`` with ``//`` and ``/* ... */`` comments deleted.

    Works on text that is already in memory; nothing is read from or
    written to disk.

    Args:
        content: Source text to clean.
        file_type: ``"java"`` or ``"kotlin"``; selects which pair of
            module-level comment patterns is applied.

    Raises:
        ValueError: if ``file_type`` is neither of the two values above.

    NOTE: the patterns are plain regular expressions, so ``//`` or ``/*``
    that appear inside a string literal are removed as well.
    """
    if file_type == "java":
        single_line, multi_line = java_single_line_comment_regex, java_multiline_comment_regex
    elif file_type == "kotlin":
        single_line, multi_line = kotlin_single_line_comment_regex, kotlin_multiline_comment_regex
    else:
        raise ValueError(f"Unsupported file type: {file_type}")

    # One alternation handles both comment kinds in a single pass.
    return re.sub(single_line + "|" + multi_line, "", content)

def num_tokens_from_string(string: str, encoding_name: str =  "cl100k_base") -> int:
    """Count the tokens ``string`` encodes to under the named tiktoken encoding.

    Args:
        string: Text to measure.
        encoding_name: Name passed to ``tiktoken.get_encoding``.

    Returns:
        The length of the encoded token sequence.
    """
    tokens = tiktoken.get_encoding(encoding_name).encode(string)
    return len(tokens)


def get_completion(prompt, model="gpt-3.5-turbo"):
    """Send ``prompt`` as a single user message and return the reply text.

    Args:
        prompt: Content of the one user message in the request.
        model: Model name forwarded to ``openai.ChatCompletion.create``.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        Exception: when the first choice reports ``finish_reason == "length"``.
    """
    completion = openai.ChatCompletion.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,  # degree of randomness of the model's output
    )
    first_choice = completion.choices[0]
    if first_choice["finish_reason"] == "length":
        raise Exception("finish_reason == length")
    return first_choice.message["content"]

def remove_markdown(text):
    """Delete Markdown code-fence markers from ``text``.

    Every "```java" marker is removed first, then every remaining "```".
    All other text, including the newlines around the fences, is kept.
    """
    for fence in ('```java', '```'):
        text = text.replace(fence, '')
    return text
import re
from typing import List

import re
from typing import List

def has_one_open_and_close_parenthesis(input_string):
    """Locate the ``(`` of a string that holds exactly one ``(`` and one ``)``.

    Returns:
        The index of the ``(`` when the string contains exactly one ``(``
        and exactly one ``)``; -1 in every other case.
    """
    pair_counts = (input_string.count('('), input_string.count(')'))
    return input_string.find('(') if pair_counts == (1, 1) else -1

def contains_forbidden_characters(input_string, forbidden_characters):
    """Return True when ``input_string`` has NO character from ``forbidden_characters``.

    NOTE: the result is the opposite of what the name suggests. True means
    the string is clean, False means a forbidden character was found.
    """
    for char in input_string:
        if char in forbidden_characters:
            return False
    return True

def clean_parameter_text(raw_params:str):
    """Normalise the text of a parameter list into ``|``-separated parameters.

    Newlines are treated as spaces.  Commas outside any ``<>``, ``{}`` or
    ``()`` nesting become ``|`` separators; commas inside nesting are kept.
    Spaces inside nesting are dropped.  Each resulting parameter is stripped
    and empty ones are discarded.

    The ``>`` of a Kotlin ``->`` arrow is not counted as a closing bracket.
    Counting it pushed the nesting depth below zero, so every parameter
    after a function-typed one was glued onto it instead of being split off.

    Args:
        raw_params: Text found between the parentheses of a declaration.

    Returns:
        The parameters joined with ``|``; ``""`` when there are none.
    """
    openers = {'<', '{', '('}
    closers = {'>', '}', ')'}
    depth = 0
    previous = ''
    kept = []
    for ch in raw_params.replace('\n', ' '):
        # The depth test uses the depth *before* this character is counted,
        # so an opening bracket itself is still treated as top level.
        if ch != ' ' or depth == 0:
            kept.append('|' if ch == ',' and depth == 0 else ch)

        if ch in openers:
            depth += 1
        elif ch in closers and not (ch == '>' and previous == '-'):
            depth -= 1
        previous = ch

    pieces = (piece.strip() for piece in ''.join(kept).split('|'))
    return '|'.join(piece for piece in pieces if piece != '')


def find_parentheses_indices(code: str, start_index: int, char = '()') -> tuple:
    """Find a bracket pair in ``code``, scanning forward from ``start_index``.

    Args:
        code: Text to scan.
        start_index: Position at which scanning begins.
        char: Two-character string holding the opening and closing bracket.

    Returns:
        ``(open_index, close_index)``.  ``open_index`` is the position of the
        opening bracket seen while the nesting depth was zero, or
        ``start_index`` when none was seen.  ``close_index`` is the position
        of the bracket that brings the depth back to zero, or -1 when the
        scan reaches the end of ``code`` first.
    """
    depth = 0
    open_at = start_index
    close_at = -1

    for position in range(start_index, len(code)):
        symbol = code[position]
        if symbol == char[0]:
            if depth == 0:
                open_at = position
            depth += 1
        elif symbol == char[1]:
            depth -= 1
            if depth == 0:
                close_at = position
                break
    return open_at, close_at



class Function:
    """Name and normalised parameter types of a parsed declaration.

    The parsers also use this class for class declarations, in which case
    only ``name`` is set and ``parameter_types`` is empty.
    """

    # Kotlin type name -> spelling used when comparing against Java.
    _KOTLIN_TO_JAVA = {
        'Int': 'int',
        'Long': 'long',
        'Short': 'short',
        'Byte': 'byte',
        'Double': 'double',
        'Float': 'float',
        'Char': 'char',
        'Boolean': 'boolean',
        'String': 'String',  # Add other types if necessary
        'MutableSet': 'Set',
        'Any': 'Object',
    }
    # Whole-word match only: plain substring replacement turned ``Interval``
    # into ``interval`` and ``Anything`` into ``Objectthing``.
    _TYPE_NAME = re.compile(r'\b(?:' + '|'.join(map(re.escape, _KOTLIN_TO_JAVA)) + r')\b')

    def __init__(self, name: str = "", parameter_types: List[str] = None):
        """Store ``name`` and the normalised form of each parameter type.

        Each type is stripped, has every ``?`` removed and is then passed
        through ``convert_primitive_type``.  ``None`` means no parameters.
        """
        self.name = name
        if parameter_types is None:
            parameter_types = []
        self.parameter_types = [
            Function.convert_primitive_type(t.strip().replace('?', ''))
            for t in parameter_types
        ]

    def __repr__(self):
        return f"Function(name={self.name!r}, parameter_types={self.parameter_types!r})"

    @staticmethod
    def convert_primitive_type(kotlin_type: str) -> str:
        """Replace each mapped Kotlin type name in ``kotlin_type`` with its Java spelling.

        Only whole identifiers are replaced, so a mapped name inside a
        generic argument (``List<Int>``) is converted while a longer
        identifier that merely contains one (``Interval``) is left alone.
        """
        return Function._TYPE_NAME.sub(
            lambda found: Function._KOTLIN_TO_JAVA[found.group(0)], kotlin_type
        )

class KotlinParser:
    """Regex-based extraction of declarations from Kotlin source text."""

    @staticmethod
    def get_classes(source_code: str) -> List[Function]:
        """Return one ``Function`` (name only) per ``class``/``interface`` keyword match."""
        class_pattern = re.compile(r'\b(?:class|interface|enum class)\s+(\w+)')
        return [Function(name=found.group(1)) for found in class_pattern.finditer(source_code)]

    @staticmethod
    def remove_nullable_marker(param_type: str) -> str:
        """Return ``param_type`` with every ``?`` removed and surrounding whitespace stripped."""
        return param_type.replace('?','').strip()

    @staticmethod
    def _parameter_type(parameter: str) -> str:
        """Return the declared type of one ``name: Type [= default]`` parameter.

        The text is cut at the first ``:`` only, so a type that itself
        contains ``:`` stays intact, and any ``= default`` suffix is dropped.
        A parameter without ``:`` is returned unchanged rather than raising.
        """
        _, colon, declared = parameter.partition(':')
        if not colon:
            return parameter
        return declared.split('=', 1)[0]

    @staticmethod
    def get_functions(source_code: str) -> List[Function]:
        """Return a ``Function`` for every ``fun name(...)`` found in ``source_code``.

        An optional type-parameter list between ``fun`` and the name
        (``fun <T> name(``) is accepted.  A function without parameters
        gets the single parameter type ``""``.
        """
        method_pattern = re.compile(r'\bfun\s+(?:<[^()]*?>\s*)?([\w.]+)\s*\(')

        functions = []
        for found in method_pattern.finditer(source_code):
            # The first "(" after ``fun`` opens the parameter list; the
            # helper returns the position of its matching ")".
            open_index, close_index = find_parentheses_indices(source_code, found.start())
            clean_params = clean_parameter_text(source_code[open_index + 1:close_index])
            if clean_params != "":
                types = [KotlinParser._parameter_type(p) for p in clean_params.split('|')]
            else:
                types = [""]
            functions.append(Function(name=found.group(1), parameter_types=types))
        return functions



class JavaParser:
    """Regex-based extraction of declarations from Java source text."""

    @staticmethod
    def get_classes(source_code: str) -> List[Function]:
        """Return one ``Function`` (name only) per ``class``/``interface``/``enum`` keyword match."""
        class_pattern = re.compile(r'\b(?:class|interface|enum)\s+(\w+)')
        return [Function(name=found.group(1)) for found in class_pattern.finditer(source_code)]

    @staticmethod
    def _parameter_type(parameter: str) -> str:
        """Return the type token of one ``[final] Type name`` parameter.

        The ``final`` modifier is skipped; previously it was reported as the
        parameter's type.  Returns ``""`` when nothing else is left.
        """
        tokens = [token for token in parameter.split() if token != 'final']
        return tokens[0] if tokens else ''

    @staticmethod
    def get_functions(source_code: str) -> List[Function]:
        """Return a ``Function`` for every method header found in ``source_code``.

        A header is text that starts with one of the listed modifiers and
        runs up to the next ``{`` or ``;`` without containing ``=``.  Only
        headers with exactly one ``(`` and one ``)`` are kept; the method
        name is the last whitespace-separated token before the ``(``.  A
        method without parameters gets the single parameter type ``""``.
        """
        method_pattern = re.compile(r'\b(?:public|private|protected|static|final|synchronized|abstract|native|strictfp)\s+[^=;{]*({|;)')

        functions = []
        for found in method_pattern.finditer(source_code):
            header = found.group()
            if '{' in header:
                # Keep only the method head; bodies start at "{", interface
                # methods end with ";" and have no "{" to cut at.
                header = header[:header.find('{')]

            open_p = has_one_open_and_close_parenthesis(header)
            if open_p == -1:
                continue
            # Helper returns True when no forbidden character is present.
            if not contains_forbidden_characters(header[:open_p], set(['='])):
                continue

            method_name = header[:open_p].split()[-1]
            start, end = find_parentheses_indices(header, header.find(method_name))
            clean_params = clean_parameter_text(header[start + 1:end])
            if clean_params != "":
                types = [JavaParser._parameter_type(p) for p in clean_params.split('|')]
            else:
                types = [""]
            functions.append(Function(name=method_name, parameter_types=types))
        return functions



class CompareFiles:
    """Heuristic comparison of a Kotlin source with its Java translation.

    The entry point is ``compare``; the other methods produce the individual
    parts of its report.
    """

    @staticmethod
    def remove_annotations(java_code: str) -> str:
        """Remove annotations, including a flat argument list, from ``java_code``.

        ``@Name``, ``@pkg.Name`` and ``@Name(args)`` are deleted together with
        the whitespace that follows.  Arguments were previously left behind
        (``@Suppress("x")`` became ``("x")``), which added a stray pair of
        parentheses to the method header.  Arguments that contain nested
        parentheses are still left in place.
        """
        return re.sub(r'@\w+(?:\.\w+)*(?:\([^()]*\))?\s*', '', java_code)

    @staticmethod
    def check_package_declaration(kotlin_code: str,java_code: str) -> bool:
        """Tell whether the Kotlin ``package`` line also occurs in ``java_code``.

        Returns False when ``kotlin_code`` has no package line of its own.
        The check is a substring test of ``package <name>`` against the Java
        text, so it is case-sensitive.
        """
        package_declaration_pattern = re.compile(r'^\s*package\s+(\w+(\.\w+)*)\s*$', re.MULTILINE)
        found = package_declaration_pattern.search(kotlin_code)
        if not found:
            return False
        return found.group(0).strip() in java_code

    @staticmethod
    def check_syntax_keywords(java_source: str) -> bool:
        """Return True when ``java_source`` contains ``fun``, ``val`` or ``var`` as a whole word."""
        kotlin_keywords = ["fun", "val", "var"]  # Add more keywords as needed
        return any(
            re.search(r'\b{}\b'.format(keyword), java_source)
            for keyword in kotlin_keywords
        )

    @staticmethod
    def compare_classes(kotlin_classes, java_classes) -> Tuple[str,int]:
        """Report class names present on only one side.

        Args:
            kotlin_classes: Set of class names found in the Kotlin source.
            java_classes: Set of class names found in the Java source.

        Returns:
            ``(report, score)`` where ``score`` counts the names that occur
            on one side only.  Names are listed in sorted order so the
            report does not depend on set iteration order.
        """
        missing_in_java = sorted(kotlin_classes - java_classes)
        extra_in_java = sorted(java_classes - kotlin_classes)

        report = ""
        if missing_in_java:
            report += "\n//Classes missing in Java:\n"
            report += "".join(f"//- {name}\n" for name in missing_in_java)
        if extra_in_java:
            report += "\n//Classes extra in Java:\n"
            report += "".join(f"//+ {name}\n" for name in extra_in_java)
        return report, len(missing_in_java) + len(extra_in_java)

    @staticmethod
    def compare_functions(kotlin_functions: List["Function"], java_functions: List["Function"]) -> Tuple[str,int]:
        """Report functions present on only one side.

        Two functions match when both ``name`` and ``parameter_types`` are
        equal.

        Returns:
            ``(report, score)`` where ``score`` counts the unmatched
            functions of both sides.
        """
        common_functions = []
        missing_in_java = []
        for kotlin_func in kotlin_functions:
            matching_java_funcs = [
                java_func for java_func in java_functions
                if java_func.name == kotlin_func.name
                and java_func.parameter_types == kotlin_func.parameter_types
            ]
            if matching_java_funcs:
                common_functions.extend(matching_java_funcs)
            else:
                missing_in_java.append(kotlin_func)

        # Java functions that no Kotlin function matched.
        extra_in_java = [java_func for java_func in java_functions if java_func not in common_functions]

        report = ""
        if missing_in_java:
            report += "\n//Functions missing in Java:\n"
            for func in missing_in_java:
                report += f"//- {func.name}({', '.join(func.parameter_types)})\n"
        if extra_in_java:
            report += "\n//Functions extra in Java:\n"
            for func in extra_in_java:
                report += f"//+ {func.name}({', '.join(func.parameter_types)})\n"
        return report, len(missing_in_java) + len(extra_in_java)

    @staticmethod
    def _brackets_balanced(code: str, pair: str) -> bool:
        """Return True when ``code`` has as many ``pair[0]`` as ``pair[1]`` characters."""
        return code.count(pair[0]) == code.count(pair[1])

    @staticmethod
    def compare(kotlin_source: str, java_source: str) -> Tuple[str,int]:
        """Compare a Kotlin source with its Java translation.

        The Java text is first cut after its last ``}`` and stripped of
        annotations.  Classes and functions of both sides are then compared
        and the bracket counts of the Java text are checked.

        Returns:
            ``(report, score)``.  ``score`` is the number of class and
            function mismatches, plus 100 when ``()`` or ``{}`` are not
            balanced in the Java text; 0 means no such problem was found.
            Leftover Kotlin keywords and package mismatches appear in the
            report but do not change the score.
        """
        java_source = java_source[:java_source.rfind('}')+1]
        java_source = CompareFiles.remove_annotations(java_source)

        kotlin_classes = {cls.name for cls in KotlinParser.get_classes(kotlin_source)}
        java_classes = {cls.name for cls in JavaParser.get_classes(java_source)}

        kotlin_functions = KotlinParser.get_functions(kotlin_source)
        java_functions = JavaParser.get_functions(java_source)

        syntax = (
            CompareFiles._brackets_balanced(java_source, '()')
            and CompareFiles._brackets_balanced(java_source, '{}')
        )

        class_report, class_score = CompareFiles.compare_classes(kotlin_classes, java_classes)
        fun_report, fun_score = CompareFiles.compare_functions(kotlin_functions, java_functions)
        final_score = class_score + fun_score + (0 if syntax else 100)

        report = "\n//⚠!#!" if final_score != 0 else ""
        report += "\n//--------------------Class--------------------"
        report += class_report
        report += "\n//-------------------Functions-----------------"
        report += fun_report
        report += "\n//-------------------Extra---------------------"
        if not syntax:
            report += "\n//Bracket problem"
        if CompareFiles.check_syntax_keywords(java_source):
            report += "\n//Found syntax problems"
        if not CompareFiles.check_package_declaration(kotlin_source, java_source):
            report += "\n//Issue with package decleration"

        report += "\n//---------------------------------------------"
        return report, final_score
# Example usage:
kotlin_code = """
package java.test
class A {


    @Test
    fun upvoteStory_whenUpvoteSuccessful() = runBlocking {
        // Given that the use case responds with success
        whenever(upvoteStory(storyId)).thenReturn(Result.Success(Unit))
        // And the view model is constructed
        val viewModel = withViewModel()
        var result: Result<Unit>? = null

        // When upvoting a story
        viewModel.storyUpvoteRequested(storyId) { result = it }

        // Then the result is successful
        assertEquals(Result.Success(Unit), result)
    }




    fun methodA(param1: Int, param2: String) {
        // Method body
    }

    fun method_A(){}

    fun methodB() {
        // Method body
    }
    fun anotation(drawerView : View?, slideOffset: Float) {
        super.onDrawerSlide(drawerView, slideOffset);
    }
}

class B {
    fun onlyKotlin() {
        // Method body
    }
}
"""

java_code = """
Package java.test;
class A {


    @Test
    public void upvoteStory_whenUpvoteSuccessful() throws Exception {
        whenever(upvoteStory.invoke(storyId)).thenReturn(Result.Success(null));

        StoryViewModel viewModel = withViewModel();
        Result<Object> result = null;

        viewModel.storyUpvoteRequested(storyId, (r) -> result = r);

        Assert.assertEquals(Result.Success(null), result);
    }





    public void methodA(int param1,
        String    param2) {
        // Method body
    }
    public static void methodB(){}
    private void method_A(){}

    public void onlyJava() {
        // Method body
    }
        public void anotation(@NonNull View drawerView, float slideOffset) {
        super.onDrawerSlide(drawerView, slideOffset);
    }

    @Test
    public void clickOnAndroidHomeIcon_OpensAndClosesNavigation() {
        Espresso.onView(withId(R.id.drawer_layout))
                .check(matches(DrawerMatchers.isClosed(Gravity.START)));

        clickOnHomeIconToOpenNavigationDrawer();
        checkDrawerIsOpen();
    }
}

class C {
    fun methodC() {
        // Method body
    }
Note: The `@Suppress("UNCHECKED_CAST")` class annotation in the Kotlin code is not needed in Java because the `create` method in the `ViewModelProvider.Factory` interface already has a generic type parameter `<T extends ViewModel>`.</s>
"""

# Comparing Kotlin and Java files
# Runs the comparison on the two sample sources defined above and prints the
# textual report; the numeric score is returned alongside but not printed.
report,score = CompareFiles.compare(kotlin_code, java_code)
print(report)

In [None]:
proj_name = 'leakcanary_Diff'
merged_url = f'https://raw.githubusercontent.com/benymaxparsa/Kotlin_projects_commit_diff/main/{proj_name}-merged.txt'
paths_url = f'https://raw.githubusercontent.com/benymaxparsa/Kotlin_projects_commit_diff/main/{proj_name}-paths.txt'
# Fifth "/"-separated piece of the URL; later used as the directory that is
# zipped and removed by save_files().
dir_name = merged_url.split('/')[4]
print(dir_name)
print(paths_url)


def _download(url, filename):
    """Save the body of ``url`` to ``filename``.

    Raises ``requests.HTTPError`` on an error status, so an error page is
    never stored and then parsed as if it were the data.
    """
    reply = requests.get(url, allow_redirects=True)
    reply.raise_for_status()
    with open(filename, 'wb') as out:
        out.write(reply.content)


# All Kotlin sources in one file, separated by the "<code block>" marker.
_download(merged_url, 'merged.txt')
with open('merged.txt', 'r', encoding='utf-8') as file:
    content = file.read()

scripts = content.split('<code block>')
print(f"code blocks: {len(scripts)}")
# print(scripts[0].strip())


# One path per line.
_download(paths_url, 'paths.txt')
with open('paths.txt', 'r', encoding='utf-8') as file:
    paths = file.readlines()

print(f"files: {len(paths)}")


# rename .kt to .java
# splitext replaces only the final extension; joining the "."-split pieces
# with "" also deleted every other dot in the path.
new_paths = [os.path.splitext(path.strip())[0] + '.java' for path in paths]
print(new_paths[0])

def get_file_name(index):
    """Return the ``.java`` output path of the file at ``index``."""
    return new_paths[index]

# Commit id -> index of its first file / index one past its last file.
begin_dico = dict()
end_dico = dict()
def get_commit(path):
    """Return the text after the last ``-`` in the third ``/``-separated piece of ``path``."""
    return path.split('/')[2].split('-')[-1]

commits = [get_commit(c) for c in paths]
unique = set(commits)
for position, commit_id in enumerate(commits):
    begin_dico.setdefault(commit_id, position)
    end_dico[commit_id] = position + 1

longest = max(scripts, key=len)
print(f"Longest files has << {num_tokens_from_string(longest)} >> tokens")


incomplete_files = []
translation_report =[]

# Words removed from a commit's refactoring-description vocabulary before it
# is matched against file and class names.
_NON_CLASS_WORDS = frozenset({
    'changed', 'class', 'classes', 'close', 'from', 'in',
    'with', 'is', 'of', 'open', 'to',
})


def translate(files_index:List[int],commit_words=None,index_to_commit=None,translate_all = True):
    """Translate the Kotlin code blocks at ``files_index`` to Java.

    Args:
        files_index: Indices into the module-level ``scripts`` and
            ``new_paths`` lists.
        commit_words: Optional sequence of ``(commit_sha, words)`` tuples.
            The words of the entry whose sha starts with
            ``index_to_commit[i]`` are used for filtering.
        index_to_commit: Mapping from file index to commit id prefix.
            Only read when ``commit_words`` is given.
        translate_all: When False, a file is skipped unless its file name
            (without extension) or one of its Kotlin class names is among
            the commit's words.  A file whose commit has no words is skipped.

    Returns:
        A list of ``[text, index]`` pairs.  ``text`` is the comparison
        report followed by the best-scoring reply, with Markdown fences
        removed.

    Raises:
        Exception: when a request fails with a message that mentions the
            monthly quota.  Any other failure is recorded in the output
            text and in ``incomplete_files`` instead.

    Side effects:
        Appends to ``incomplete_files`` and ``translation_report``, sleeps
        five seconds before every request and prints the skip count.
    """
    index_to_commit = {} if index_to_commit is None else index_to_commit
    result = []
    skipped = 0
    for index in tqdm(files_index):
        text = scripts[index].strip()
        file_stem = new_paths[index].split('/')[-1].split('.')[0]
        kotlin_classes = KotlinParser.get_classes(text)

        list_of_classes = None
        if commit_words is not None:
            for data_tuple in commit_words:
                # Check if the first element of the tuple starts with the desired string
                if data_tuple[0].startswith(index_to_commit[index]):
                    list_of_classes = data_tuple[1]

        if not translate_all:
            # ``or ()``: no matching commit used to crash in ``set(None)``;
            # it is now treated like a commit without any words.
            relevant_words = set(list_of_classes or ()) - _NON_CLASS_WORDS
            class_names = {k.name for k in kotlin_classes}
            if file_stem not in relevant_words and not (class_names & relevant_words):
                skipped += 1
                translation_report.append(file_stem)
                continue

        prompts =[
        f"""Translate the given Kotlin code to Java, adhering to the following constraints:
{'In case functions  outside class exist in kotlin, put the translated java functions in a class using name of the file which is:' + new_paths[index].split('/')[-1].replace(".java","") if len(kotlin_classes)==0 else ""}
Preserve the original names of classes, fields, and methods without renaming.
Translate the entire class, including all its fields, methods, inner classes, etc.
Do not create any new classes, methods, or fields.
Make no assumptions about the code, meaning not helper classes or functions.
Keep the translation as close as possible to the original code.
Under no circumstances add new classes, methods, or fields beyond what's in the original.
Do not add new methods even if the reference is not available.
Here's the original code block:""",]

        best_response = ""
        best_score = -1  # -1 means that no attempt has been scored yet
        for p in prompts:
            prompt = f"""{p}
```{text}```"""

            time.sleep(5)  # Delay for 5 seconds
            try:
                response = get_completion(prompt)
                report, score = CompareFiles.compare(text, response)
                candidate = report + '\n' + clear_response(response)
            except Exception as e:
                if 'exceeded quota for this month' in str(e): # using the actual exception from the API did not work
                    raise Exception(f'exceeded quota for this month, till this index:{index}') from e
                # Keep the error text as this file's output so the failure
                # is visible in the saved file.
                response = f"File: {get_file_name(index)}\n⚠ Error: {e}"
                incomplete_files.append(get_file_name(index))
                report, score = CompareFiles.compare(text, response)
                candidate = report + '\n' + response

            # Lower score = fewer detected problems; keep the best attempt.
            if best_score == -1 or score < best_score:
                best_score = score
                best_response = candidate

            if score == 0:
                break

        result.append([remove_markdown(best_response), index])
    print(f"skipped: {skipped}")
    return result


def save_files(result):
    """Write translated sources to disk, zip ``dir_name`` and delete it.

    Args:
        result: ``[text, index]`` pairs as returned by ``translate``; each
            ``text`` is written to ``new_paths[index]``.

    The number of files is printed first.  An empty ``result`` (for example
    a chunk in which every file was skipped) returns right after that:
    there is nothing to archive, and indexing the first element used to
    raise ``IndexError``.

    The archive is named ``GPT-<proj_name>-<first index>-<last index>.zip``.
    """
    print(len(result))
    if not result:
        return

    for source_text, index in result:
        target = new_paths[index].strip()
        folder = os.path.dirname(target)
        if folder:  # a bare file name has no directory to create
            os.makedirs(folder, exist_ok=True)
        # Explicit UTF-8: the report text can contain non-ASCII characters.
        with open(target, 'w', encoding='utf-8') as f:
            f.write(source_text)

    first_index, last_index = result[0][1], result[-1][1]
    shutil.make_archive(f'GPT-{proj_name}-{first_index}-{last_index}', 'zip', dir_name)
    shutil.rmtree(dir_name)

In [None]:
import json
import requests

def remove_between_parentheses(input_string):
    """Delete every ``(`` ... ``)`` group from ``input_string``.

    A group runs from a ``(`` to the nearest following ``)``, so nested
    parentheses are not paired up: ``"a(b(c)d)e"`` becomes ``"ad)e"``.
    """
    parenthesised = re.compile(r'\([^)]*\)')
    return parenthesised.sub('', input_string)

# Function to download JSON file from URL
def download_json(url):
    """Fetch ``url`` and return its body parsed as JSON.

    Raises:
        Exception: when the response status code is anything other than 200.
    """
    reply = requests.get(url)
    if reply.status_code != 200:
        raise Exception("Failed to download JSON file")
    return reply.json()

# Function to extract description values and split them
commit_words = []
def extract_descriptions(data):
    """Collect the words of each commit's refactoring descriptions.

    For every item of ``data`` one ``(sha1, words)`` tuple is appended to the
    module-level ``commit_words`` list.  ``words`` holds the
    whitespace-separated words of all non-empty ``description`` values under
    the item's ``refactorings`` key, after parenthesised groups were removed
    and ``.`` and ``,`` were replaced by spaces.  Missing keys count as empty.
    """
    for entry in data:
        words = []
        sha = entry.get("sha1", '')
        for refactoring in entry.get("refactorings", []):
            text = refactoring.get("description", "")
            if not text:
                continue
            text = remove_between_parentheses(text)
            words.extend(text.replace(".", " ").replace(",", " ").split())

        commit_words.append((sha, words))


# URL of the JSON file
# Two project URLs are listed below; only the uncommented assignment is used.
# Sunflower
# json_url = "https://drive.google.com/uc?export=download&id=1h1ZuljIs3sThLo31Nm2pkwOCBrz3uHLI"
# LeakCanary
json_url = "https://drive.google.com/uc?export=download&id=13lmA4E-bvVeqkt6gs7BCnCg9JpnsfogP"

# Download JSON file
json_data = download_json(json_url)

# Extract descriptions
# Fills the module-level commit_words list; the call's return value is not used.
extract_descriptions(json_data)

In [None]:
# Commits whose files are going to be translated.
commits = ["297fa598"]
chunk_size = 30
print(f"calculating files for each chunk of size {chunk_size}")


# File indices of all selected commits, in commit order.
all_files_index = []
for commit in commits:
    all_files_index.extend(range(begin_dico[commit], end_dico[commit]))
print(len(all_files_index))

# Show which indices end up in which chunk.
for chunk_start in range(0, len(all_files_index), chunk_size):
    print(all_files_index[chunk_start:chunk_start + chunk_size])

print("\ncalculating index to commit dict")

# Reverse lookup: file index -> commit it belongs to.
index_to_commit = {
    file_index: commit
    for commit in commits
    for file_index in range(begin_dico[commit], end_dico[commit])
}

In [None]:
# Start from an empty failure list, then translate and save chunk by chunk.
incomplete_files = []
for chunk_start in range(0, len(all_files_index), chunk_size):
    chunk = all_files_index[chunk_start:chunk_start + chunk_size]
    result = translate(chunk, commit_words, index_to_commit, translate_all=False)
    save_files(result)

print(f"{len(incomplete_files)} files are bad")

print("\nSkipped Files:")
for skipped_file in translation_report:
    print(skipped_file)

In [None]:
# Show the distinct description words recorded for the commit whose sha
# starts with "297f"; the last matching entry wins.
list_of_classes = None
if commit_words is not None:
    for data_tuple in commit_words:
        # Check if the first element of the tuple starts with the desired string
        if data_tuple[0].startswith("297f"):
            list_of_classes = data_tuple[1]
# ``or ()`` shows an empty set instead of raising when no commit matched.
set(list_of_classes or ())