## Code search (Java)
A simple implementation that parses source files and extracts method information from the Java files in the [mybatis-plus-samples repository](https://gitee.com/baomidou/mybatis-plus-samples).
___

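### How to extract Java methods

#### Message from you:
Using Python, extract the code that belongs to methods in Java source code, along with each method's name.

#### Message from ChatGPT:
You can use a regular expression to match the method names and method bodies in Java code.

Example code: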

```python
import re

# Assume the Java code is stored in a file
with open('example.java', 'r') as f:
    java_code = f.read()

# Define the regular expression
method_pattern = r'(?:public|private|protected)?\s*(?:void|int|float|double|char|byte|short|long|boolean)\s+(\w+)\s*\((.*?)\)(?:(?:throws\s+\w+\s*,?)*\s*)\{([\s\S]*?)\}'

# Match method names and method bodies
matches = re.findall(method_pattern, java_code)

# Print the results
for match in matches:
    method_name = match[0]
    method_body = match[2]
    print('Method name:', method_name)
    print('Method body:', method_body)
```

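This code prints the name and body of every method in the Java file that the pattern matches. Here `method_name` is the method's name and `method_body` is the method body (all the code inside the braces). The `method_pattern` regular expression is what matches the names and bodies, and it can be modified as needed.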
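
To see what the pattern actually captures, it can be run against a small inline snippet. The sketch below is illustrative only: `SAMPLE_JAVA` is a hypothetical source file, and the pattern is copied unchanged from the answer. Under these assumptions it suggests two limits: methods whose return type is not `void` or a primitive are skipped, and the lazy `[\s\S]*?\}` stops at the first closing brace, so a body containing a nested block is cut short.

```python
import re

# Pattern copied unchanged from the answer above.
METHOD_PATTERN = r'(?:public|private|protected)?\s*(?:void|int|float|double|char|byte|short|long|boolean)\s+(\w+)\s*\((.*?)\)(?:(?:throws\s+\w+\s*,?)*\s*)\{([\s\S]*?)\}'

# Hypothetical Java source, used only for illustration.
SAMPLE_JAVA = """
public class Calculator {
    public int add(int a, int b) {
        return a + b;
    }

    public String describe() {
        return "calculator";
    }

    private void loop(int n) {
        for (int i = 0; i < n; i++) {
            System.out.println(i);
        }
    }
}
"""

# Each match is a tuple: (method name, parameter list, body text).
matches = re.findall(METHOD_PATTERN, SAMPLE_JAVA)
names = [m[0] for m in matches]
print(names)  # ['add', 'loop']  (describe() returns String, so it is not matched)

# The body of loop() ends at the first '}', i.e. the end of the for block.
print(matches[1][2])
```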

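#### Message from you:
`method_body` should contain the entire method, not just the code inside the braces.

#### Message from ChatGPT:
I understand that you want the result to cover the whole method: the method name, the parameters, and the body. In that case, the regular expression can capture the method name and parameters as well, and the content of the whole method can then be assembled and stored in `method_content`.

Example code: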

```python
import re

# Assume the Java code is stored in a file
with open('example.java', 'r') as f:
    java_code = f.read()

# Define the regular expression
method_pattern = r'(?:public|private|protected)?\s*(?:void|int|float|double|char|byte|short|long|boolean)\s+(\w+)\s*\((.*?)\)(?:(?:throws\s+\w+\s*,?)*\s*)\{([\s\S]*?)\}'

# Match method names, parameters, and method bodies
matches = re.findall(method_pattern, java_code)

# Print the results
for match in matches:
    method_name = match[0]
    method_params = match[1]
    method_body = match[2]
    method_content = f'{method_name}({method_params})' + '{' + method_body + '}'
    print('Method content:', method_content)
```

In this version of the code, `method_body` still holds only the code inside the braces, but the content of the whole method can now be assembled as follows:

```python
method_content = f'{method_name}({method_params})' + '{' + method_body + '}'
```

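Here, `f'{method_name}({method_params})'` builds the method name and parameter part, and `'{' + method_body + '}'` wraps the method body in braces again; concatenating the two gives the content of the whole method.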
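
The reassembled string leaves out the access modifier and return type, because the pattern matches them without capturing them. One possible alternative, sketched here under the assumption that the same pattern is kept, is to iterate with `re.finditer` and take `match.group(0)`, which is the full matched text. The helper name `extract_full_methods` and the `SAMPLE_JAVA` snippet are hypothetical.

```python
import re

# Same pattern as above, compiled once.
METHOD_PATTERN = re.compile(
    r'(?:public|private|protected)?\s*(?:void|int|float|double|char|byte|short|long|boolean)'
    r'\s+(\w+)\s*\((.*?)\)(?:(?:throws\s+\w+\s*,?)*\s*)\{([\s\S]*?)\}'
)

def extract_full_methods(java_code):
    """Return (name, full_text) pairs; full_text includes modifier and return type."""
    return [
        # group(1) is the method name, group(0) is everything the pattern matched.
        (m.group(1), m.group(0).strip())
        for m in METHOD_PATTERN.finditer(java_code)
    ]

# Hypothetical Java source, used only for illustration.
SAMPLE_JAVA = """
public class Greeter {
    public void greet(String name) {
        System.out.println(name);
    }
}
"""

methods = extract_full_methods(SAMPLE_JAVA)
for name, text in methods:
    print(name)
    print(text)
```

The first-closing-brace limitation of the pattern still applies; this only changes which part of the match is kept.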
___

In [1]:
import javalang
import os
# glob is used to find all matching files in a directory
from glob import glob
# pandas is used to store the results
# https://pandas.pydata.org/docs/index.html
import pandas as pd

import re



def get_functions(filepath):
    # Assume the Java code is stored in a file
    with open(filepath, 'r', encoding='utf-8') as f:
        java_code = f.read()

    # Define the regular expression
    method_pattern = r'(?:public|private|protected)?\s*(?:void|int|float|double|char|byte|short|long|boolean)\s+(\w+)\s*\((.*?)\)(?:(?:throws\s+\w+\s*,?)*\s*)\{([\s\S]*?)\}'

    # Match method names, parameters, and method bodies
    matches = re.findall(method_pattern, java_code)

    # Print each match and yield it as a record
    for match in matches:
        method_name = match[0]
        method_params = match[1]
        method_body = match[2]
        method_content = f'{method_name}({method_params})' + '{' + method_body + '}'
        print('Method content:\n', method_content)
        yield {"code": method_content, "function_name": method_name, "filepath": filepath}

# get the user's home directory
root_dir = os.path.expanduser("~")
# note: for this code to work, the mybatis-plus-samples repo must be downloaded and placed at the path below
# path to code repository directory
code_root = root_dir + "/WorkSpaces/IntelliJIdeaPorjects/mybatis-plus-samples"
# os.walk() walks the directory tree and yields (dirpath, dirnames, filenames) for each directory.
# os.path.join() joins path components using the platform's path separator.
code_files = [y for x in os.walk(code_root) for y in glob(os.path.join(x[0], '*.java'))]
print("Total number of Java files:", len(code_files))

# make sure the output directory exists before writing the file list
os.makedirs("data", exist_ok=True)
with open("data/code_files.text", 'w') as f:
    for code_file in code_files:
        f.write(code_file)
        f.write("\n")

if len(code_files) == 0:
    print("Double check that you have downloaded the mybatis-plus-samples repo and set the code_root variable correctly.")

all_funcs = []
# get all functions in all files
for code_file in code_files:
    funcs = list(get_functions(code_file))
    for func in funcs:
        all_funcs.append(func)

print("Total number of functions extracted:", len(all_funcs))




Total number of Java files: 197
Method content:
 t_c(){
        final Child child = childMapper.selectLinkById(1L);
        log.info("child: {}
Method content:
 t_m(){
        final Man man = manMapper.selectLinkById(1L);
        log.info("man: {}
Method content:
 t_w(){
        final Woman woman = womanMapper.selectLinkById(1L);
        log.info("woman: {}
Method content:
 main(String[] args){
        SpringApplication.run(ResultmapApplication.class, args);
    }
Method content:
 aInsert(){
        User user = new User();
        user.setName("小羊");
        user.setAge(3);
        user.setEmail("abc@mp.com");
        assertThat(mapper.insert(user)).isGreaterThan(0);
        // 成功直接拿回写的 ID
        assertThat(user.getId()).isNotNull();
    }
Method content:
 bDelete(){
        assertThat(mapper.deleteById(3L)).isGreaterThan(0);
        assertThat(mapper.delete(new QueryWrapper<User>()
                .lambda().eq(User::getName, "Sandy"))).isGreaterThan(0);
    }
Method content:
 cUpdate()
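
Several of the bodies printed above stop at the first `}` (for example at `log.info("child: {}`), because the pattern's lazy body group ends at the first closing brace. A possible workaround, sketched below, is to use a regular expression only for the method header and then count braces to find the matching close. Everything here is an assumption for illustration: the `HEADER` pattern, the helper names, and the `SAMPLE_JAVA` source are hypothetical, string and character literals are skipped, and braces inside comments are not handled.

```python
import re

# Hypothetical header pattern: access modifier, optional static, return type, name, parameters.
HEADER = re.compile(
    r'(?:public|private|protected)\s+(?:static\s+)?[\w<>\[\]]+\s+(\w+)'
    r'\s*\([^)]*\)\s*(?:throws\s+[\w.,\s]+?)?\s*\{'
)

def find_matching_brace(code, open_idx):
    """Return the index of the '}' that closes the '{' at open_idx, or -1."""
    depth = 0
    i = open_idx
    while i < len(code):
        ch = code[i]
        if ch in '"\'':
            # Skip a string or char literal, honouring backslash escapes.
            quote = ch
            i += 1
            while i < len(code) and code[i] != quote:
                i += 2 if code[i] == '\\' else 1
        elif ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                return i
        i += 1
    return -1

def extract_methods(java_code):
    """Yield one record per method, with the full text up to the matching brace."""
    pos = 0
    while True:
        m = HEADER.search(java_code, pos)
        if m is None:
            return
        close = find_matching_brace(java_code, m.end() - 1)
        if close == -1:
            return
        yield {"function_name": m.group(1), "code": java_code[m.start():close + 1]}
        pos = close + 1

# Hypothetical Java source, loosely modelled on the truncated output above.
SAMPLE_JAVA = """
public class ChildTest {
    @Test
    public void t_c() {
        final Child child = childMapper.selectLinkById(1L);
        log.info("child: {}", child);
        if (child != null) {
            assertThat(child.getId()).isNotNull();
        }
    }

    public List<String> names() {
        return List.of("}");
    }
}
"""

methods = list(extract_methods(SAMPLE_JAVA))
for method in methods:
    print(method["function_name"])
    print(method["code"])
```

The records use the same `function_name` and `code` keys as `get_functions`, so a generator like this could be swapped in with a `filepath` key added. A full parser would be more reliable than either approach.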

In [6]:
!pip install --upgrade plotly
!pip install --upgrade matplotlib

Collecting plotly
  Downloading plotly-5.14.1-py2.py3-none-any.whl (15.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 15.3/15.3 MB 4.7 MB/s eta 0:00:00
Installing collected packages: plotly
Successfully installed plotly-5.14.1


In [13]:
# get embeddings for all functions
from openai.embeddings_utils import get_embedding

# pandas.DataFrame() converts a list of dicts (or a dict / Series) into a DataFrame.
df = pd.DataFrame(all_funcs)
# Series.apply() calls the function on each value of the column.
df['code_embedding'] = df['code'].apply(lambda x: get_embedding(x, engine='text-embedding-ada-002'))
# store file paths relative to the repository root
df['filepath'] = df['filepath'].apply(lambda x: x.replace(code_root, ""))
# DataFrame.to_csv() writes the DataFrame to a CSV file.
df.to_csv("data/code_search_mybatis-plus-samples.csv", index=False)
# DataFrame.head() returns the first n rows (5 by default).
df.head()

|  | code | function_name | filepath | code_embedding |
|---|---|---|---|---|
| 0 | `t_c(){\n final Child child = childMappe...` | `t_c` | `/mybatis-plus-sample-resultmap/src/test/java/c...` | `[0.01701400801539421, 0.014936479739844799, -0...` |
| 1 | `t_m(){\n final Man man = manMapper.sele...` | `t_m` | `/mybatis-plus-sample-resultmap/src/test/java/c...` | `[-0.006725418381392956, -0.00554285105317831, ...` |
| 2 | `t_w(){\n final Woman woman = womanMappe...` | `t_w` | `/mybatis-plus-sample-resultmap/src/test/java/c...` | `[-0.02467336133122444, -0.008370584808290005, ...` |
| 3 | `main(String[] args){\n SpringApplicatio...` | `main` | `/mybatis-plus-sample-resultmap/src/main/java/c...` | `[0.0033424673601984978, 0.007332277484238148, ...` |
| 4 | `aInsert(){\n User user = new User();\n ...` | `aInsert` | `/mybatis-plus-sample-crud/src/test/java/com/ba...` | `[-0.030493764206767082, 0.008744916878640652, ...` |
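
One consequence of saving the results as CSV is that each embedding list is written as text, so it comes back as a string when the file is reloaded in a later session and has to be parsed before any similarity can be computed. The sketch below shows the effect with the standard library only; the rows and their three-element vectors are hypothetical stand-ins for the real data.

```python
import ast
import csv
import io

# Hypothetical rows in the same shape as the records above (toy vectors).
rows = [
    {"code": "t_c(){ }", "function_name": "t_c", "code_embedding": [0.1, 0.2, 0.3]},
    {"code": "main(String[] args){ }", "function_name": "main", "code_embedding": [-0.4, 0.5, 0.6]},
]

# Write the rows to an in-memory CSV; the csv module stringifies the lists with str().
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=["code", "function_name", "code_embedding"])
writer.writeheader()
writer.writerows(rows)

# Read the CSV back: every field, including the embedding, is now a string.
buffer.seek(0)
loaded = list(csv.DictReader(buffer))
print(type(loaded[0]["code_embedding"]).__name__)  # str

# Parse each embedding string back into a list of floats.
for row in loaded:
    row["code_embedding"] = ast.literal_eval(row["code_embedding"])

print(loaded[0]["code_embedding"])  # [0.1, 0.2, 0.3]
```

With pandas, the equivalent step after `pd.read_csv` would be to apply `ast.literal_eval` to the `code_embedding` column.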


In [14]:
# cosine_similarity computes the cosine similarity between two vectors
from openai.embeddings_utils import cosine_similarity

def search_functions(df, code_query, n=3, pprint=True, n_lines=7):
    """Search for functions in the code repository that are similar to a given query."""
    embedding = get_embedding(code_query, engine='text-embedding-ada-002')
    # Series.apply() calls the function on each stored embedding.
    df['similarities'] = df.code_embedding.apply(lambda x: cosine_similarity(x, embedding))
    # DataFrame.sort_values() sorts the rows by the given column.
    res = df.sort_values('similarities', ascending=False).head(n)
    if pprint:
        # DataFrame.iterrows() yields (index, row) pairs.
        for r in res.iterrows():
            print(r[1].filepath+":"+r[1].function_name + "  score=" + str(round(r[1].similarities, 3)))
            print("\n".join(r[1].code.split("\n")[:n_lines]))
            print('-'*70)
    return res

res = search_functions(df, 'testCustomSqlSegment', n=3)

/mybatis-plus-sample-deluxe/src/test/java/com/baomidou/mybatisplus/samples/deluxe/DeluxeTest.java:testCustomSqlSegment  score=0.888
testCustomSqlSegment(){
        QueryWrapper<User> ew = new QueryWrapper<>();
        ew.like("u.name", "Tom");
        List<User> list = mapper.customerSqlSegment(ew);
        Assertions.assertEquals(1, list.size());
    }
----------------------------------------------------------------------
/mybatis-plus-sample-quickstart-springmvc/src/test/java/com/baomidou/mybatisplus/samples/quickstart/springmvc/SpringMvcTest.java:testCustomizedSql  score=0.826
testCustomizedSql(){
        System.out.println("maxAge=" + userMapper.selectMaxAge());
    }
----------------------------------------------------------------------
/mybatis-plus-sample-reduce-springmvc/src/test/java/com/baomidou/mybatisplus/samples/reduce/springmvc/ReduceTest.java:testCustomizedSql  score=0.826
testCustomizedSql(){
        System.out.println("maxAge=" + userMapper.selectMaxAge());
    }
----------------------------------------------------------------------
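
The scores above are cosine similarities between the query embedding and each stored code embedding, and the results are the rows with the highest scores. The ranking step can be sketched without any API call; the helper names `cosine_sim` and `top_n` and the three-dimensional toy vectors below are hypothetical and stand in for real embeddings.

```python
import math

def cosine_sim(a, b):
    """Cosine similarity: dot product divided by the product of the vector norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def top_n(query_vec, records, n=3):
    """Return the n (score, function_name) pairs with the highest similarity."""
    scored = [
        (cosine_sim(query_vec, record["code_embedding"]), record["function_name"])
        for record in records
    ]
    return sorted(scored, key=lambda pair: pair[0], reverse=True)[:n]

# Toy records; real embeddings would have many more dimensions.
records = [
    {"function_name": "alpha", "code_embedding": [1.0, 0.0, 0.0]},
    {"function_name": "beta", "code_embedding": [0.6, 0.8, 0.0]},
    {"function_name": "gamma", "code_embedding": [0.0, 0.0, 1.0]},
]

ranking = top_n([1.0, 0.0, 0.0], records, n=2)
for score, name in ranking:
    print(name, round(score, 3))
```

Because every stored embedding is compared with the query, the cost of this approach grows linearly with the number of extracted methods.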

In [15]:
res = search_functions(df, 'verifyGithub1532', n=3)

/mybatis-plus-sample-deluxe/src/test/java/com/baomidou/mybatisplus/samples/deluxe/DeluxeTest.java:verifyGithub1532  score=0.846
verifyGithub1532(){
        mapper.findList(new User().setName("a")).forEach(System.out::println);
    }
----------------------------------------------------------------------
/mybatis-plus-sample-tenant/src/test/java/com/baomidou/mybatisplus/samples/tenant/TenantTest.java:manualSqlTenantFilterTest  score=0.75
manualSqlTenantFilterTest(){
        System.out.println(mapper.myCount());
    }
----------------------------------------------------------------------
/mybatis-plus-sample-id-generator/src/main/java/com/baomidou/samples/IdGeneratorApplication.java:main  score=0.74
main(String[] args){
        SpringApplication.run(IdGeneratorApplication.class, args);
    }
----------------------------------------------------------------------


In [16]:
res = search_functions(df, 'DdlMysqlApplication', n=1, n_lines=20)

/mybatis-plus-sample-ddl-mysql/src/main/java/com/baomidou/mybatisplus/samples/ddl/mysql/DdlMysqlApplication.java:main  score=0.865
main(String[] args){
        SpringApplication.run(DdlMysqlApplication.class, args);
    }
----------------------------------------------------------------------
