# GPT-3 embeddings for code on IJ Community

ToC

## Java & Kotlin Parsers

In [None]:
%pip install -r requirement.txt
## tree-sitter, openAI

# mkidr -p parsers
!git clone 'https://github.com/tree-sitter/tree-sitter-java' parsers/tree-sitter-java
!git clone 'https://github.com/fwcd/tree-sitter-kotlin' parsers/tree-sitter-kotlin

In [61]:
from tree_sitter import Language, Parser

Language.build_library( # Build and store native libraries for parsers
  'build/my-languages.so',
  [
    'parsers/tree-sitter-java',
    'parsers/tree-sitter-kotlin'
  ]
)

True

In [103]:
# both queries include decaration \w empty body (e.g from interfaces)

JA_LANGUAGE = Language('build/my-languages.so', 'java')
jParser = Parser()
jParser.set_language(JA_LANGUAGE)
jQuery = JA_LANGUAGE.query("""
(method_declaration
  name: (identifier) @method.name) @method.decl
""")

KT_LANGUAGE = Language('build/my-languages.so', 'kotlin')
kParser = Parser()
kParser.set_language(KT_LANGUAGE)
kQuery = KT_LANGUAGE.query("""
(function_declaration
  (simple_identifier)+ @func.name) @func.decl
""")

In [104]:
jTree = jParser.parse(bytes("""
public class Test {
	boolean isValid();
	public static void main(String[] args){
		System.out.println("Hello, World!");
	}
	fianl String test;
	static @NotNull ModalTaskOwner project(@NotNull Project project) {
		return ApplicationManager.getApplication().getService(TaskSupport.class).modalTaskOwner(project);
	}
}
""", "utf8"))

jcaptures = jQuery.captures(jTree.root_node)
assert len(jcaptures) == 6
[(c[0].type, c[0].text) for c in jcaptures]

[('method_declaration', b'boolean isValid();'),
 ('identifier', b'isValid'),
 ('method_declaration',
  b'public static void main(String[] args){\n\t\tSystem.out.println("Hello, World!");\n\t}'),
 ('identifier', b'main'),
 ('method_declaration',
  b'static @NotNull ModalTaskOwner project(@NotNull Project project) {\n\t\treturn ApplicationManager.getApplication().getService(TaskSupport.class).modalTaskOwner(project);\n\t}'),
 ('identifier', b'project')]

In [105]:
kTree = kParser.parse(bytes("""
package org.kotlinlang.play         // 1

fun main() {                        // 2
    println("Hello, World!")        // 3
}                                   // 4

fun onChanged()

fun multiply(x: Int, y: Int = N) = x * y// 5

override fun getMessageBus(): MessageBus {
    error("Not supported")
}

class Customer                      // 6
""", "utf8"))

kcaptures = kQuery.captures(kTree.root_node)
# assert len(kcaptures) == 6
[(c[0].type, c[0].text) for c in kcaptures]

[('function_declaration',
  b'fun main() {                        // 2\n    println("Hello, World!")        // 3\n}'),
 ('simple_identifier', b'main'),
 ('simple_identifier', b'main'),
 ('function_declaration', b'fun onChanged()'),
 ('simple_identifier', b'onChanged'),
 ('simple_identifier', b'onChanged'),
 ('function_declaration', b'fun multiply(x: Int, y: Int = N) = x * y'),
 ('simple_identifier', b'multiply'),
 ('simple_identifier', b'multiply'),
 ('function_declaration',
  b'override fun getMessageBus(): MessageBus {\n    error("Not supported")\n}'),
 ('simple_identifier', b'getMessageBus'),
 ('simple_identifier', b'getMessageBus')]

## Parse & extract funtions

In [88]:
import codecs
import os
import re
from pathlib import Path
all_funcs = []
parsers = { ".java": jParser, ".kt": kParser}
queries =  { ".java": jQuery, ".kt": kQuery}

code_root = "../intellij-community/platform"
code_files = "ij-communit-platform-files.csv"
code_functions = "ij-communit-platform-functions.csv"
decl_re = re.compile('.*\.decl')
name_re = re.compile('.*\.name')

def get_functions(rel_path: str):
    file = os.path.join(code_root, rel_path)
    # print(f"reading '{file}'")
    with open(file, 'rb') as f:
        ext = rel_path[rel_path.rindex("."):]
        p = parsers[ext]
        tree = p.parse(f.raw.readall())

        q = queries[ext]
        cs = iter(q.captures(tree.root_node))

        for c in cs:
            fn_code = ""
            try:
                if decl_re.match(c[1]):
                    fn_code = codecs.utf_8_decode(c[0].text)[0]
                    next_c = next(cs)
                    fn_name = codecs.utf_8_decode(next_c[0].text)[0]
                    if ext == ".kt": # Kotlin parser matches the name twice ¯\_(ツ)_/¯
                        assert next_c[0].text == next(cs)[0].text, f"{rel_path}: {fn_code}"
                elif name_re.match(c[1]):
                    fn_name = codecs.utf_8_decode(c[0].text)[0]
                else:
                    assert False, "neither declaration nor name"
                yield {"code": fn_code, "function_name": fn_name, "filepath": rel_path}
            except StopIteration as e:
                print(f"{rel_path}: no name after {fn_code}")

with open(code_files) as f:
    # read, parse and query every file (find all the functions)
    for code_file in f.readlines():
        funcs = list(get_functions(code_file.strip()))
        for func in funcs:
            all_funcs.append(func)


In [283]:
import pandas as pd

code_functions = "ij-communit-platform-functions.gz.csv"

df = pd.DataFrame(all_funcs)
df.to_csv(code_functions, index=False, compression='gzip')
df.head()

Unnamed: 0,code,function_name,filepath
0,@Override\n protected void update(final AnAct...,update,lang-api/src/com/intellij/execution/ui/actions...
1,private boolean isToFocus(final ViewContext co...,isToFocus,lang-api/src/com/intellij/execution/ui/actions...
2,@Override\n protected void actionPerformed(fi...,actionPerformed,lang-api/src/com/intellij/execution/ui/actions...
3,@Override\n public @NotNull ActionUpdateThrea...,getActionUpdateThread,lang-api/src/com/intellij/execution/ui/actions...
4,@Override\n public final void update(final @N...,update,lang-api/src/com/intellij/execution/ui/actions...


## Tokenize & embeddings

In [95]:
# count tokens
model="text-embedding-ada-002"

import tiktoken
tokenizer = encoding = tiktoken.encoding_for_model(model) #tiktoken.get_encoding("cl100k_base")

df['n_tokens'] = df['code'].apply(lambda x: len(tokenizer.encode(x)))


In [34]:
df.sample(5)

Unnamed: 0,code,function_name,filepath,n_tokens
145946,static int getUsedHashAlgorithmVersion() {\n ...,getUsedHashAlgorithmVersion,indexing-impl/src/com/intellij/psi/impl/cache/...,25
172798,private void resizeComboToFitCustomShortcut() ...,resizeComboToFitCustomShortcut,lang-impl/src/com/intellij/codeInsight/templat...,49
56300,public boolean postponeValidation() {\n ret...,postponeValidation,lang-impl/src/com/intellij/ide/util/projectWiz...,12
49396,@Override\n public Dimension getPreferredSize...,getPreferredSize,platform-impl/src/com/intellij/ui/EditorComboB...,97
26268,public int getEnd() {\n return myEnd;\n },getEnd,core-impl/src/com/intellij/ide/highlighter/cus...,13


In [89]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 175311 entries, 0 to 175310
Data columns (total 4 columns):
 #   Column         Non-Null Count   Dtype 
---  ------         --------------   ----- 
 0   code           175311 non-null  object
 1   function_name  175311 non-null  object
 2   filepath       175311 non-null  object
 3   n_tokens       175311 non-null  int64 
dtypes: int64(1), object(3)
memory usage: 5.4+ MB


In [96]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 175242 entries, 0 to 175241
Data columns (total 4 columns):
 #   Column         Non-Null Count   Dtype 
---  ------         --------------   ----- 
 0   code           175242 non-null  object
 1   function_name  175242 non-null  object
 2   filepath       175242 non-null  object
 3   n_tokens       175242 non-null  int64 
dtypes: int64(1), object(3)
memory usage: 5.3+ MB


In [90]:
df.describe()

Unnamed: 0,n_tokens
count,175311.0
mean,68.900126
std,126.841344
min,0.0
25%,20.0
50%,34.0
75%,73.0
max,13406.0


In [97]:
df.describe()

Unnamed: 0,n_tokens
count,175242.0
mean,68.927255
std,126.858943
min,3.0
25%,20.0
50%,34.0
75%,73.0
max,13406.0


In [91]:
df[df['n_tokens'] == 0]

Unnamed: 0,code,function_name,filepath,n_tokens
4399,,DEFAULT_NO_EMAIL_ZENDESK_REQUESTER,feedback/src/com/intellij/feedback/common/Gene...,0
4991,,defaultPathQuery,built-in-server/src/org/jetbrains/builtInWebSe...,0
4993,,defaultPathQuery,built-in-server/src/org/jetbrains/builtInWebSe...,0
10790,,id,platform-impl/src/com/intellij/diagnostic/hpro...,0
10792,,id,platform-impl/src/com/intellij/diagnostic/hpro...,0
...,...,...,...,...
146111,,alwaysTrue,util-ex/src/com/intellij/psi/util/psiTreeUtil.kt,0
160190,,DEFAULT_NOTIFICATION_STATUS,diff-impl/src/com/intellij/diff/tools/util/Dif...,0
160192,,DEFAULT_NOTIFICATION_STATUS,diff-impl/src/com/intellij/diff/tools/util/Dif...,0
163968,,DEFAULT_LAMBDA,ml-impl/src/com/intellij/internal/ml/ngram/NGr...,0


In [99]:
df[df['n_tokens'] == 3]

Unnamed: 0,code,function_name,filepath,n_tokens
699,void cleanup();,cleanup,analysis-api/src/com/intellij/codeInspection/l...,3
718,void cleanup();,cleanup,analysis-api/src/com/intellij/codeInspection/l...,3
1076,void run();,run,editor-ui-api/src/com/intellij/openapi/actionS...,3
1200,boolean isValid();,isValid,usageView/src/com/intellij/usages/Usage.java,3
1308,boolean isValid();,isValid,usageView/src/com/intellij/usages/UsageTarget....,3
...,...,...,...,...
170086,void cancel();,cancel,core-impl/src/com/intellij/concurrency/Job.java,3
170739,void split();,split,platform-api/src/com/intellij/ui/content/Tabbe...,3
172463,void cancel();,cancel,core-api/src/org/jetbrains/concurrency/Cancell...,3
173406,void apply();,apply,lang-impl/src/com/intellij/ide/util/gotoByName...,3


In [141]:
df['n_tokens'].sum()

12078950

In [143]:
df[df['n_tokens'] > 8192]

Unnamed: 0,code,function_name,filepath,n_tokens,tokens
77709,private static int[] splitIntegerList() {\n ...,splitIntegerList,util/zip/src/org/jetbrains/ikv/RecSplitSetting...,13406,"[2039, 1118, 528, 1318, 6859, 3570, 861, 368, ..."


In [107]:
print(f"Price: ${(df['n_tokens'].sum() / 1000) * 0.0004:.2f}")
print(f"Min time: {((df['n_tokens'].sum() / 2048) / 3500):.2f} min")
# ~5.900 request total, 2k each

# limit
#  3,500 RPM
#  350,000 TPM

Price: $4.83


Time: 1.69 min


In [133]:
import numpy as np

df['tokens'] = df['code'].apply(lambda x: tokenizer.encode(x))

In [148]:
type(df['tokens'][0])

list

In [175]:
import numpy as np
import openai
# from openai.embeddings_utils import get_embedding

openai.api_key = os.getenv("OPENAI_API_KEY")

# batch by 2048
# each max 8192
df_n = df[df['n_tokens'] < 8192]


In [None]:
df_n['emb'] = df_n.groupby(np.arange(len(df_n))//2048)['tokens'].transform(lambda x: get_emb(x))

In [286]:
df_n.head()

Unnamed: 0,code,function_name,filepath,n_tokens,tokens,code_embedding
0,@Override\n protected void update(final AnAct...,update,lang-api/src/com/intellij/execution/ui/actions...,67,"[6123, 198, 220, 2682, 742, 2713, 10285, 1556,...","[-0.016020899638533592, -0.01102778222411871, ..."
1,private boolean isToFocus(final ViewContext co...,isToFocus,lang-api/src/com/intellij/execution/ui/actions...,39,"[2039, 2777, 374, 1271, 14139, 10285, 2806, 20...","[0.002539371605962515, -0.01106214802712202, 0..."
2,@Override\n protected void actionPerformed(fi...,actionPerformed,lang-api/src/com/intellij/execution/ui/actions...,63,"[6123, 198, 220, 2682, 742, 25026, 10285, 1556...","[-0.009785722009837627, -0.014558453112840652,..."
3,@Override\n public @NotNull ActionUpdateThrea...,getActionUpdateThread,lang-api/src/com/intellij/execution/ui/actions...,25,"[6123, 198, 220, 586, 571, 11250, 5703, 4387, ...","[-0.05378969386219978, -0.02131587639451027, 0..."
4,@Override\n public final void update(final @N...,update,lang-api/src/com/intellij/execution/ui/actions...,76,"[6123, 198, 220, 586, 1620, 742, 2713, 10285, ...","[-0.02128949761390686, -0.009933851659297943, ..."


In [278]:
df_n.drop(columns=['emb'], inplace=True)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_n.drop(columns=['emb'], inplace=True)


In [281]:
from typing import List

from tenacity import retry, stop_after_attempt, wait_random_exponential


@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
    list_of_tokens: List[int], engine="text-similarity-babbage-001"
) -> List[List[float]]:
    assert len(list_of_tokens) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    # list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = openai.Embedding.create(input=list_of_tokens, engine=engine).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


In [282]:
# df_n['emb']            = df_n.groupby(np.arange(len(df_n))//2048)['tokens'].transform(lambda x: get_emb(x.tolist()))
df_n['code_embedding'] = df_n.groupby(np.arange(len(df_n))//2048)['tokens'].transform(lambda x: get_embeddings(x.tolist(), engine=model))

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_n['code_embedding'] = df_n.groupby(np.arange(len(df_n))//2048)['tokens'].transform(lambda x: get_embeddings(x.tolist(), engine=model))


In [284]:
code_emb_file = "ij-communit-platform-functions-embeddings.gz.csv"

df_n.to_csv(code_emb_file, index=False, compression='gzip')


In [289]:
type(df_n.code_embedding[0][0])

float

In [290]:
from openai.embeddings_utils import cosine_similarity
from openai.embeddings_utils import get_embedding

def search_functions(df, code_query, n=3, pprint=True, n_lines=7):
    embedding = get_embedding(code_query, engine='text-embedding-ada-002')
    df['similarities'] = df.code_embedding.apply(lambda x: cosine_similarity(x, embedding))

    res = df.sort_values('similarities', ascending=False).head(n)
    if pprint:
        for r in res.iterrows():
            print(r[1].filepath+":"+r[1].function_name + "  score=" + str(round(r[1].similarities, 3)))
            print("\n".join(r[1].code.split("\n")[:n_lines]))
            print('-'*70)
    return res

In [294]:
res = search_functions(df_n, 'API for syntax highliting', n=5)

usageView/src/com/intellij/usages/Usage.java:highlightInEditor  score=0.831
void highlightInEditor();
----------------------------------------------------------------------
testFramework/src/com/intellij/testFramework/fixtures/CodeInsightTestFixture.java:checkHighlighting  score=0.825
long checkHighlighting();
----------------------------------------------------------------------
platform-api/src/com/intellij/openapi/options/colors/ColorSettingsPage.java:getHighlighter  score=0.823
@NotNull SyntaxHighlighter getHighlighter();
----------------------------------------------------------------------
editor-ui-ex/src/com/intellij/openapi/editor/ex/util/LayeredHighlighterIterator.java:getActiveSyntaxHighlighter  score=0.813
@NotNull
  SyntaxHighlighter getActiveSyntaxHighlighter();
----------------------------------------------------------------------
xdebugger-impl/src/com/intellij/xdebugger/impl/ui/DebuggerColorsPage.java:getHighlighter  score=0.809
@Override
  @NotNull
  public SyntaxHigh

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['similarities'] = df.code_embedding.apply(lambda x: cosine_similarity(x, embedding))


In [295]:
df[df['function_name'] == 'getNavigationOffset']

Unnamed: 0,code,function_name,filepath
1205,default int getNavigationOffset() {\n FileE...,getNavigationOffset,usageView/src/com/intellij/usages/Usage.java
1239,@Override\n public int getNavigationOffset() ...,getNavigationOffset,usageView/src/com/intellij/usages/UsageInfo2Us...
1277,int getNavigationOffset();,getNavigationOffset,usageView/src/com/intellij/usages/UsageInfoAda...
78941,public int getNavigationOffset() {\n if (my...,getNavigationOffset,core-api/src/com/intellij/usageView/UsageInfo....


In [303]:
print(df.function_name.describe())
print()
print(df.function_name.value_counts().head(50))

count              175242
unique              62280
top       actionPerformed
freq                 1986
Name: function_name, dtype: object

function_name
actionPerformed                 1986
toString                        1700
update                          1659
getActionUpdateThread           1413
dispose                         1029
equals                           995
hashCode                         919
getInstance                      900
getName                          620
run                              605
create                           552
getIcon                          528
get                              505
getText                          419
isEnabled                        404
remove                           403
apply                            401
getProject                       395
isSelected                       373
reset                            370
clear                            359
getState                         358
setSelected                     

In [292]:
df_n.head()

Unnamed: 0,code,function_name,filepath,n_tokens,tokens,code_embedding,similarities
0,@Override\n protected void update(final AnAct...,update,lang-api/src/com/intellij/execution/ui/actions...,67,"[6123, 198, 220, 2682, 742, 2713, 10285, 1556,...","[-0.016020899638533592, -0.01102778222411871, ...",0.657746
1,private boolean isToFocus(final ViewContext co...,isToFocus,lang-api/src/com/intellij/execution/ui/actions...,39,"[2039, 2777, 374, 1271, 14139, 10285, 2806, 20...","[0.002539371605962515, -0.01106214802712202, 0...",0.65743
2,@Override\n protected void actionPerformed(fi...,actionPerformed,lang-api/src/com/intellij/execution/ui/actions...,63,"[6123, 198, 220, 2682, 742, 25026, 10285, 1556...","[-0.009785722009837627, -0.014558453112840652,...",0.669121
3,@Override\n public @NotNull ActionUpdateThrea...,getActionUpdateThread,lang-api/src/com/intellij/execution/ui/actions...,25,"[6123, 198, 220, 586, 571, 11250, 5703, 4387, ...","[-0.05378969386219978, -0.02131587639451027, 0...",0.644399
4,@Override\n public final void update(final @N...,update,lang-api/src/com/intellij/execution/ui/actions...,76,"[6123, 198, 220, 586, 1620, 742, 2713, 10285, ...","[-0.02128949761390686, -0.009933851659297943, ...",0.647052


In [273]:
from typing import List

def get_emb(batch: List[List[int]]):
     print(type(batch))
     return [[len(x)] for x in batch]

x = df_n \
    .groupby(np.arange(len(df_n))//2048) \
    # .apply(list) \
    # .apply(lambda x: np.asarray(x).shape)

for k, group in x: # 86
     tokens_batch = group['tokens'].to_list()
     group['tokens_shape'] = group['tokens'].apply(list).transform(len)
     # print(tb)
     # group['tokens_shape'] = pd.DataFrame(get_emb(tokens_batch))
     print(group.head())
     break

                                                code          function_name   
0  @Override\n  protected void update(final AnAct...                 update  \
1  private boolean isToFocus(final ViewContext co...              isToFocus   
2  @Override\n  protected void actionPerformed(fi...        actionPerformed   
3  @Override\n  public @NotNull ActionUpdateThrea...  getActionUpdateThread   
4  @Override\n  public final void update(final @N...                 update   

                                            filepath  n_tokens   
0  lang-api/src/com/intellij/execution/ui/actions...        67  \
1  lang-api/src/com/intellij/execution/ui/actions...        39   
2  lang-api/src/com/intellij/execution/ui/actions...        63   
3  lang-api/src/com/intellij/execution/ui/actions...        25   
4  lang-api/src/com/intellij/execution/ui/actions...        76   

                                              tokens   emb  tokens_shape  
0  [6123, 198, 220, 2682, 742, 2713, 10285, 1556,... 

In [251]:
import openai
openai.api_key = os.getenv("OPENAI_API_KEY")

# print(openai.Model.list())
# response = openai.Embedding.create(model=model, input=[[64], [65]])
# print(response)

# df_n['code_embedding'] = df['tokens'].apply(lambda x: get_embedding(x, engine=model))



In [None]:
embeddings = response['data'][0]['embedding']

print(embeddings)