In [1]:
import pandas as pd
import json

from beir.retrieval.search.lexical import BM25Search as BM25


from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from beir.retrieval.models import SPLADE, SentenceBERT, UniCOIL
from beir.retrieval.search.sparse import SparseSearch


from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from github_search.evaluation.beir_evaluation import EvaluateRetrievalCustom as EvaluateRetrieval, CorpusDataLoader
from beir.retrieval.search.lexical import BM25Search as BM25

import sentence_transformers

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from github_search.evaluation.beir_evaluation import EvaluateRetrievalCustom as EvaluateRetrieval, CorpusDataLoader

In [3]:
import huggingface_hub
import dotenv
import os

from dotenv import load_dotenv

load_dotenv()

huggingface_hub.login(token=os.environ["HF_TOKEN"])

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [4]:
import pickle 

with open("/home/kuba/Projects/github_search/.dagster/storage/corpus_information", "rb") as f:
    corpora = json.loads(pickle.load(f))

[(cname, len(corpora[cname].keys())) for cname in corpora.keys()]

[('readme', 6771),
 ('repomap', 6771),
 ('selected_code', 6771),
 ('dependency_signature', 6771),
 ('repository_signature', 6771),
 ('generated_tasks', 6771),
 ('code2doc_generated_readme', 6771),
 ('code2doc_files_summary', 6771),
 ('repomap_code2doc_generated_readme', 6771),
 ('repomap_code2doc_files_summary', 6771),
 ('flat_code2doc_generated_readme', 6771),
 ('flat_repomap_code2doc_generated_readme', 6771),
 ('dep_sig_code2doc_generated_readme', 6771),
 ('dep_sig_code2doc_files_summary', 6771),
 ('flat_dep_sig_code2doc_generated_readme', 6771)]

In [5]:
list(corpora["flat_repomap_code2doc_generated_readme"].items())[:10]

[('0',
  {'text': '# SincNet Demo Repository\n\nThis repository contains a collection of Python scripts designed for machine learning tasks, particularly focusing on speech processing and recognition. It includes functionalities for preparing datasets, defining neural network models, performing inference, measuring similarity, and speaker identification.\n\n## Key Functionalities\n- **Data Preparation**: Functions like `ReadList` in `data_io.py` are used to read and prepare lists of audio files.\n- **Batch Creation**: The function `create_batches_rnd` in both `data_io.py` and `speaker_id.py` is utilized for creating randomized batches of',
   'title': '008karan/SincNet_demo'}),
 ('1',
  {'text': 'The repository focuses on developing and testing machine learning models, specifically multi-agent reinforcement learning (MARL) systems. It uses a custom environment with multiple agents that can interact through various actions and perceive their surroundings using observations. The data use

In [6]:
corpora.keys()

dict_keys(['readme', 'repomap', 'selected_code', 'dependency_signature', 'repository_signature', 'generated_tasks', 'code2doc_generated_readme', 'code2doc_files_summary', 'repomap_code2doc_generated_readme', 'repomap_code2doc_files_summary', 'flat_code2doc_generated_readme', 'flat_repomap_code2doc_generated_readme', 'dep_sig_code2doc_generated_readme', 'dep_sig_code2doc_files_summary', 'flat_dep_sig_code2doc_generated_readme'])

In [7]:
def get_repos_for_query(query, repos_df):
    return repos_df[repos_df["tasks"].apply(lambda ts: query in ts)]


def get_queries(repos_df, min_query_count):
    all_queries = repos_df["query_tasks"].explode()
    qcounts = all_queries.value_counts()
    return qcounts[qcounts >= min_query_count].index.to_list()

def prepare_query_data(repos_df, min_query_count=5):
    task_queries = {str(i): query for (i, query) in enumerate(get_queries(repos_df, min_query_count=min_query_count))}

    task_qrels = {
        qid: {str(corpus_id): 1 for corpus_id in get_repos_for_query(task_queries[qid], repos_df).index}
        for qid in task_queries.keys()
    }
    return task_queries, task_qrels

In [8]:
with open("/home/kuba/Projects/github_search/.dagster/storage/repos_with_representations_df", "rb") as f:
    sampled_repos_df = pickle.load(f)


repos_sorted = [rec["title"] for rec in list(corpora["readme"].values())]
sampled_repos_df = pd.Series(repos_sorted, name="repo").reset_index().merge(sampled_repos_df, on="repo")
task_queries, task_qrels = prepare_query_data(sampled_repos_df, min_query_count=10)

In [9]:
len(task_queries.values())

742

In [10]:
pd.Series(map(len, task_qrels.values())).describe()

count    742.000000
mean      23.827493
std       30.461274
min       10.000000
25%       11.000000
50%       13.000000
75%       22.000000
max      271.000000
dtype: float64

In [11]:
sampled_repos_df = sampled_repos_df[sampled_repos_df["tasks"].apply(len) <= 10]

with open("../output/elasticsearch/queries_qrels.json", "w") as f:
    json.dump({"task_queries": task_queries, "task_qrels": task_qrels}, f)

In [12]:
for cid in corpora["readme"].keys():
    assert corpora["readme"][cid]["title"] == corpora["readme"][cid]["title"], f"no match at {cid}"
    #assert corpora["readme"][cid]["title"] == corpora[("dependency_signature", 0)][cid]["title"], f"no match at {cid}"

In [13]:
## Checking elasticsearch

In [14]:
import elasticsearch

es_client = elasticsearch.Elasticsearch()
def retrieve_repos_with_es(query, k=50, index="readme", es_client=es_client):
    es_result = es_client.search(index=index, body={"query": {"match": {"txt": query}}}, size=k)
    return [
        hit["_source"]["title"]
        for hit in es_result["hits"]["hits"]
    ]



def get_elasticsearch_results():
    retrieved_repo_tasks = {}

    qcounts = sampled_repos_df["tasks"].explode().value_counts()
    used_queries = [
        query
        for query in sampled_repos_df["tasks"].explode().drop_duplicates()
        if qcounts.loc[query] > 5
    ]
    # [task_queries[qid] for qid in task_queries.keys()]
    
    index="selected_code"
    for query in used_queries:
        retrieved_tasks = sampled_repos_df[sampled_repos_df["repo"].isin(retrieve_repos_with_es(query, index=index))]["tasks"].to_list()
        retrieved_repo_tasks[query] = retrieved_tasks
    
    k = 10
    query_hits = pd.Series({
        query: sum([query in tasks for tasks in retrieved_repo_tasks[query][:k]])
        for query in retrieved_repo_tasks.keys()
    })

def show_elasticsearch_results(qid='10'):
    query = task_queries[qid]
    
    print(query)
    print(query_hits[query], "hits")
    
    for hit in es_client.search(index=index, body={"query": {"match": {"txt": task_queries[qid]}}}, size=k)["hits"]["hits"]:
        print("#" * 100)
        print("#" * 100)
        repo_name = hit["_source"]["title"]
        repo_record = sampled_repos_df[sampled_repos_df["repo"] == repo_name].iloc[0]
        is_hit = query in repo_record["tasks"]
        print(repo_name, "HIT" if is_hit else "NO HIT")
        
        if is_hit:
            print("#" * 100)
            print("#" * 100)
            print(hit['_source']['txt'])

## Evaluating with BEIR

In [15]:
def load_w2v_sentence_transformer(w2v_model_path):
    w2v_layer = sentence_transformers.models.WordEmbeddings.load(w2v_model_path)
    model = sentence_transformers.SentenceTransformer(modules=[w2v_layer, sentence_transformers.models.Pooling(200)])
    model.max_seq_length = 2048
    return model


def load_sentence_bert(model_name):
    st_model = SentenceBERT("sentence-transformers/all-mpnet-base-v2")
    st_model.doc_model = sentence_transformers.SentenceTransformer(model_name, trust_remote_code=True)
    st_model.q_model = st_model.doc_model
    return st_model

def get_w2v_retriever(w2v_model_path="../models/rnn_abstract_readme_w2v/0_WordEmbeddings"):
    w2v_model = load_w2v_sentence_transformer(w2v_model_path)
    st_model = SentenceBERT("sentence-transformers/all-mpnet-base-v2")
    st_model.q_model = w2v_model
    st_model.doc_model = w2v_model
    return EvaluateRetrieval(DRES(st_model), score_function="cos_sim")

def get_splade_retriever(splade_model_path = "splade/weights/distilsplade_max", batch_size=128):
    splade_model = DRES(SPLADE(splade_model_path), batch_size=128)
    return EvaluateRetrieval(splade_model, score_function="dot")


def get_bm25_retrievers(corpora):
    def sanitize_index_name(index_name):
        if type(index_name) is str:
            return index_name
        else:
            return "".join(map(str, index_name))
    
    bm25_retrievers = {}
    for corpus_name, corpus in corpora.items():
        model = BM25(index_name=sanitize_index_name(corpus_name))
        retriever = EvaluateRetrieval(model)
        bm25_retrievers[corpus_name] = retriever
    return bm25_retrievers


sentence_transformer_model_names = [
    "sentence-transformers/static-retrieval-mrl-en-v1",
    "sentence-transformers/all-mpnet-base-v2",
    "sentence-transformers/all-MiniLM-L12-v2",
    #"google/embeddinggemma-300m"
    #"nomic-ai/modernbert-embed-base",
    
    #"estrogen/ModernBERT-base-nli-v3"
    #"BAAI/bge-large-en-v1.5",
    #"mixedbread-ai/mxbai-embed-large-v1"
]

def get_sentence_transformer_retriever(model_name="sentence-transformers/all-mpnet-base-v2", batch_size=8):
    model = DRES(load_sentence_bert(model_name), batch_size=batch_size)
    return EvaluateRetrieval(model, score_function="cos_sim")

def get_unicoil_retriever(model_name="castorini/unicoil-msmarco-passage"):
    """
    THERE IS A BUG WITH BEIR THAT MAKES THIS UNUSABLE
    """
    model = SparseSearch(UniCOIL(model_path=model_name), batch_size=32)
    return EvaluateRetrieval(model, score_function="dot")

In [16]:
corpora.keys()

dict_keys(['readme', 'repomap', 'selected_code', 'dependency_signature', 'repository_signature', 'generated_tasks', 'code2doc_generated_readme', 'code2doc_files_summary', 'repomap_code2doc_generated_readme', 'repomap_code2doc_files_summary', 'flat_code2doc_generated_readme', 'flat_repomap_code2doc_generated_readme', 'dep_sig_code2doc_generated_readme', 'dep_sig_code2doc_files_summary', 'flat_dep_sig_code2doc_generated_readme'])

In [17]:
def get_corpus_samples(corpora, n_repos=10):
    records = []


    for k in range(n_repos):
        for cname in corpora.keys():
            if type(cname) is tuple:
                if 0 in cname:
                    display_name = cname[0]
                else:
                    continue
            else:
                display_name = cname
            record = corpora[cname][str(k)]
            record["corpus"] = display_name
            records.append(record)
    
    return pd.DataFrame.from_records(records).rename(columns = {"title": "repo_name", "corpus": "representation"}).fillna(method="ffill")

In [18]:
def get_repomaps_df(repo_names, repomap_path="../output/aider/selected_repo_maps_1024.json"):
    with open(repomap_path) as f:
        repomaps = json.load(f)

    records = []
    for repo in repo_names:
        records.append({"repo_name": repo, "text": repomaps[repo], "representation": "repomap"})
    return pd.DataFrame.from_records(records)

from pylate import indexes, models, retrieve


class PyLateBEIRWrapper:

    def __init__(self, model_name="lightonai/colbertv2.0"):
        
        self.model = models.ColBERT(
            model_name_or_path=model_name,
        )
        self.index = indexes.Voyager(
            index_folder=f"../output/pylate-index/{model_name}",
            index_name="index",
            override=True,
        )
        self.retriever = None

    def index_corpus(self, corpus):
        documents = corpus.values()
        documents_embeddings = self.model.encode(
            documents,
            batch_size=32,
            is_query=False, # Encoding documents
            show_progress_bar=True,
        )
        
        # Add the documents ids and embeddings to the Voyager index
        self.index.add_documents(
            documents_ids=corpus.keys(),
            documents_embeddings=documents_embeddings,
        )
        self.retriever = retrieve.ColBERT(index=self.index)

    def retrieve(self, query):
        return self.retriever.retrieve(query)

pylate_model = PyLateBEIRWrapper()

In [19]:
import sentence_transformers

In [20]:
w2v_retriever = get_w2v_retriever()

In [21]:
w2v_retriever

<github_search.evaluation.beir_evaluation.EvaluateRetrievalCustom at 0x72f7496d8dd0>

In [22]:
#splade_retriever = get_splade_retriever() 

# change sentence-transformers to 2.7?
sentence_transformer_retrievers = {
    model_name: get_sentence_transformer_retriever(model_name)
    for model_name in sentence_transformer_model_names
}

In [23]:
bm25_retrievers = get_bm25_retrievers(corpora)

## Per query results

In [24]:

from pydantic import BaseModel
from typing import Dict

class RetrieverInput(BaseModel):
    corpus: Dict[str, dict]
    queries: Dict[str, str]
    qrels: Dict[str, Dict[str, int]]


class RetrievalEvaluationResults(BaseModel):
    retrieval_results: Dict[str, Dict[str, float]]
    metrics: dict
    model_type: str

    @classmethod
    def from_retriever(cls, retriever, retriever_input, metric_names=["accuracy@k", "hits@k", "r_cap@k", "mrr@k"]):
        retrieval_results = retriever.retrieve(retriever_input.corpus, retriever_input.queries)
        custom_metrics = retriever.evaluate_custom_multi(retriever_input.qrels, retrieval_results, retriever.k_values, metrics=metric_names)
        other_metrics = retriever.evaluate(retriever_input.qrels, retrieval_results, retriever.k_values, ignore_identical_ids=False)
        metrics = custom_metrics | cls.tuple_to_dict(other_metrics)
        try:
            model_type = str(retriever.retriever.model)
        except:
            model_type = "bm25"
        return RetrievalEvaluationResults(metrics=metrics, model_type=model_type, retrieval_results=retrieval_results)


    @classmethod
    def tuple_to_dict(cls, dicts):
        merged_dict = {}
        for d in dicts:
            merged_dict = d | merged_dict
        return merged_dict




In [25]:
retriever_inputs = {
    corpus_name: RetrieverInput(corpus=corpus, queries=task_queries, qrels=task_qrels)
    for (corpus_name, corpus) in corpora.items()
}

In [26]:
from github_search.evaluation.beir_evaluation import PerQueryIREvaluator

In [27]:
per_query_evaluator = PerQueryIREvaluator(k_values=[1, 5, 10, 25])

In [28]:
retriever_inputs = {
    corpus_name: RetrieverInput(corpus=corpus, queries=task_queries, qrels=task_qrels)
    for (corpus_name, corpus) in corpora.items()
}

In [29]:
retriever_inputs.keys()

dict_keys(['readme', 'repomap', 'selected_code', 'dependency_signature', 'repository_signature', 'generated_tasks', 'code2doc_generated_readme', 'code2doc_files_summary', 'repomap_code2doc_generated_readme', 'repomap_code2doc_files_summary', 'flat_code2doc_generated_readme', 'flat_repomap_code2doc_generated_readme', 'dep_sig_code2doc_generated_readme', 'dep_sig_code2doc_files_summary', 'flat_dep_sig_code2doc_generated_readme'])

In [30]:
named_retrievers = {
    corpus_name: [
        ("bm25", bm25_retrievers[corpus_name]),
        ("word2vec", w2v_retriever),
    ] + list(sentence_transformer_retrievers.items())
    for corpus_name in retriever_inputs.keys()
}

In [31]:
rc = named_retrievers["readme"][1][1]

In [32]:
retriever_inputs.keys()

dict_keys(['readme', 'repomap', 'selected_code', 'dependency_signature', 'repository_signature', 'generated_tasks', 'code2doc_generated_readme', 'code2doc_files_summary', 'repomap_code2doc_generated_readme', 'repomap_code2doc_files_summary', 'flat_code2doc_generated_readme', 'flat_repomap_code2doc_generated_readme', 'dep_sig_code2doc_generated_readme', 'dep_sig_code2doc_files_summary', 'flat_dep_sig_code2doc_generated_readme'])

In [33]:
%%time
per_query_results = {
    (corpus_name, retriever_name): per_query_evaluator.get_scores(retriever=retriever, ir_data=retriever_inputs[corpus_name])
    for corpus_name in retriever_inputs.keys()
    for (retriever_name, retriever) in named_retrievers[corpus_name]
}

raw_per_query_results_df = pd.concat([
    df.assign(retriever=[retriever_name]*len(df)).assign(corpus=[corpus_name]*len(df))
    for ((corpus_name, retriever_name), df) in per_query_results.items()
])

per_query_results_df = raw_per_query_results_df.assign(
    corpus=raw_per_query_results_df["corpus"].apply(lambda cname: cname if type(cname) is str else cname[0]),
    generation=raw_per_query_results_df["corpus"].apply(lambda cname: 0 if type(cname) is str else cname[1])
)

per_query_results_df = (
    per_query_results_df
        .groupby(["query", "retriever", "corpus"]).agg("mean").drop(columns=["generation"])
        .reset_index()
)


  0%|                                                                                                                                                                                   | 0/6771 [00:00<?, ?docs/s]
que: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:03<00:00,  1.64it/s]
Batches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 108.73it/s]
Batches: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 53/53 [00:01<00:00, 36.60it/s]
Batches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

CPU times: user 40min 55s, sys: 1min 25s, total: 42min 20s
Wall time: 17min 1s


In [34]:

per_query_results_df

per_query_results_df.to_csv("../results/per_query_ir_results.csv", index=False)

(per_query_results_df
    .drop(columns=["query"])
    .groupby(["corpus", "retriever"])
    .agg("mean").reset_index(drop=False)
    .sort_values("Accuracy@10")
)[["corpus", "retriever", "Precision@10", "Accuracy@10"]]

Unnamed: 0,corpus,retriever,Precision@10,Accuracy@10
54,repomap,word2vec,0.015768,0.130728
24,dependency_signature,word2vec,0.024394,0.160377
74,selected_code,word2vec,0.027763,0.203504
73,selected_code,sentence-transformers/static-retrieval-mrl-en-v1,0.072102,0.392183
53,repomap,sentence-transformers/static-retrieval-mrl-en-v1,0.070889,0.397574
...,...,...,...,...
32,flat_dep_sig_code2doc_generated_readme,sentence-transformers/all-mpnet-base-v2,0.345013,0.885445
26,flat_code2doc_generated_readme,sentence-transformers/all-MiniLM-L12-v2,0.344609,0.886792
45,readme,bm25,0.422776,0.913747
46,readme,sentence-transformers/all-MiniLM-L12-v2,0.403504,0.919137


In [35]:

per_query_results_df

per_query_results_df.to_csv("../results/per_query_ir_results.csv", index=False)

(per_query_results_df
    .drop(columns=["query"])
    .groupby(["corpus", "retriever"])
    .agg("mean").reset_index(drop=False)
    .sort_values("Accuracy@10")
)[["corpus", "retriever", "Precision@10", "Accuracy@10"]]

Unnamed: 0,corpus,retriever,Precision@10,Accuracy@10
54,repomap,word2vec,0.015768,0.130728
24,dependency_signature,word2vec,0.024394,0.160377
74,selected_code,word2vec,0.027763,0.203504
73,selected_code,sentence-transformers/static-retrieval-mrl-en-v1,0.072102,0.392183
53,repomap,sentence-transformers/static-retrieval-mrl-en-v1,0.070889,0.397574
...,...,...,...,...
32,flat_dep_sig_code2doc_generated_readme,sentence-transformers/all-mpnet-base-v2,0.345013,0.885445
26,flat_code2doc_generated_readme,sentence-transformers/all-MiniLM-L12-v2,0.344609,0.886792
45,readme,bm25,0.422776,0.913747
46,readme,sentence-transformers/all-MiniLM-L12-v2,0.403504,0.919137


## Aggregated results

In [None]:
for corpus_name in corpora.keys():
    try:
        RetrievalEvaluationResults.from_retriever(bm25_retrievers[corpus_name], retriever_inputs[corpus_name])
    except:
        print(corpus_name)

In [None]:
bm25_results = {
    corpus_name: RetrievalEvaluationResults.from_retriever(bm25_retrievers[corpus_name], retriever_inputs[corpus_name])
    for corpus_name in corpora.keys()
}

splade_results = {
    corpus_name: RetrievalEvaluationResults.from_retriever(splade_retriever, retriever_inputs[corpus_name])
    for corpus_name in corpora.keys()
}

In [None]:
word2vec_results = {
    corpus_name: RetrievalEvaluationResults.from_retriever(w2v_retriever, retriever_inputs[corpus_name])
    for corpus_name in corpora.keys()
}

In [None]:
sentence_transformer_results = {
    (corpus_name, model_name.split("/")[1]): RetrievalEvaluationResults.from_retriever(sentence_transformer_retrievers[model_name], retriever_inputs[corpus_name])
    for corpus_name in corpora.keys()
    for model_name in sentence_transformer_model_names
}

In [None]:
bm25_metrics = [
    {"corpus": corpus_name, "retriever": "bm25", **bm25_results[corpus_name].metrics}
    for corpus_name in corpora.keys()
]

In [None]:
word2vec_metrics = [
    {"corpus": corpus_name, "retriever": "Python code word2vec", **word2vec_results[corpus_name].metrics}
    for corpus_name in corpora.keys()
]

In [None]:
#splade_metrics = [
#    {"corpus": corpus_name, "retriever": "splade", **splade_results[corpus_name].metrics}
#     for corpus_name in corpora.keys()
#]
 
sentence_transformer_metrics = [
    {"corpus": corpus_name, "retriever": f"{model_name} (sentence_transformer)", **sentence_transformer_results[(corpus_name, model_name)].metrics}
    for (corpus_name, model_name) in sentence_transformer_results.keys()
]

all_metrics_df = pd.DataFrame.from_records(bm25_metrics + word2vec_metrics +  sentence_transformer_metrics).sort_values("Hits@10", ascending=False)

In [None]:
all_metrics_df = all_metrics_df[~all_metrics_df["corpus"].isin(["flat_repomap_code2doc_files_summary", "flat_code2doc_files_summary"])]

In [None]:
all_metrics_df.shape

In [None]:
all_metrics_df["corpus"].unique()

In [None]:
print(corpora.keys())

In [None]:
pd.options.display.max_rows = 999

In [None]:
all_metrics_df[["corpus", "retriever", "Accuracy@10", "P@10"]].sort_values("P@10", ascending=False)

In [None]:
all_metrics_df.columns

In [None]:
model_name = "qwen2.5:7b-instruct"

In [None]:
all_metrics_df.to_csv(f"../output/code2doc/beir_results_{model_name}.csv", index=False)

In [91]:
#all_metrics_df.to_csv(f"../output/code2doc/beir_results_with_modernbert_{model_name}.csv", index=False)

## Results

By default we will use min_task_count=10 (as we used originally)

We can switch to smaller task counts like 5 to incorporate the fact that we use sample of repos

In [92]:
metric_df_cols = ["corpus", "retriever", "Accuracy@10", "Hits@10", "R_cap@10", "P@1", "P@5", "P@10"]

In [93]:
all_metrics_df[metric_df_cols]

Unnamed: 0,corpus,retriever,Accuracy@10,Hits@10,R_cap@10,P@1,P@5,P@10
29,readme,all-mpnet-base-v2 (sentence_transformer),0.92857,4.25202,0.42507,0.64151,0.51132,0.42507
0,readme,bm25,0.91375,4.22911,0.42278,0.59973,0.50081,0.42278
30,readme,all-MiniLM-L12-v2 (sentence_transformer),0.91914,4.03639,0.4035,0.66173,0.51482,0.4035
69,flat_code2doc_generated_readme,all-mpnet-base-v2 (sentence_transformer),0.88005,3.48922,0.34879,0.52426,0.42965,0.34879
70,flat_code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.88679,3.44744,0.34461,0.54717,0.43423,0.34461
77,flat_repomap_code2doc_generated_readme,all-mpnet-base-v2 (sentence_transformer),0.87062,3.30863,0.33073,0.51482,0.41294,0.33073
78,flat_repomap_code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.87466,3.25606,0.32547,0.53235,0.40889,0.32547
53,code2doc_generated_readme,all-mpnet-base-v2 (sentence_transformer),0.8558,3.21429,0.32129,0.52022,0.40162,0.32129
54,code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.86253,3.15768,0.31563,0.51887,0.4,0.31563
61,repomap_code2doc_generated_readme,all-mpnet-base-v2 (sentence_transformer),0.84232,3.12938,0.3128,0.49326,0.39057,0.3128


In [94]:
all_metrics_df[metric_df_cols].sort_values("Accuracy@10", ascending=False)

Unnamed: 0,corpus,retriever,Accuracy@10,Hits@10,R_cap@10,P@1,P@5,P@10
29,readme,all-mpnet-base-v2 (sentence_transformer),0.92857,4.25202,0.42507,0.64151,0.51132,0.42507
30,readme,all-MiniLM-L12-v2 (sentence_transformer),0.91914,4.03639,0.4035,0.66173,0.51482,0.4035
0,readme,bm25,0.91375,4.22911,0.42278,0.59973,0.50081,0.42278
70,flat_code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.88679,3.44744,0.34461,0.54717,0.43423,0.34461
69,flat_code2doc_generated_readme,all-mpnet-base-v2 (sentence_transformer),0.88005,3.48922,0.34879,0.52426,0.42965,0.34879
78,flat_repomap_code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.87466,3.25606,0.32547,0.53235,0.40889,0.32547
77,flat_repomap_code2doc_generated_readme,all-mpnet-base-v2 (sentence_transformer),0.87062,3.30863,0.33073,0.51482,0.41294,0.33073
54,code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.86253,3.15768,0.31563,0.51887,0.4,0.31563
53,code2doc_generated_readme,all-mpnet-base-v2 (sentence_transformer),0.8558,3.21429,0.32129,0.52022,0.40162,0.32129
10,flat_code2doc_generated_readme,bm25,0.84906,3.10782,0.31065,0.52432,0.39081,0.31162


In [95]:
all_metrics_df.groupby("corpus").apply(lambda df: df.sort_values("Accuracy@10", ascending=False).iloc[0])[metric_df_cols].sort_values("Accuracy@10", ascending=False)

  all_metrics_df.groupby("corpus").apply(lambda df: df.sort_values("Accuracy@10", ascending=False).iloc[0])[metric_df_cols].sort_values("Accuracy@10", ascending=False)


Unnamed: 0_level_0,corpus,retriever,Accuracy@10,Hits@10,R_cap@10,P@1,P@5,P@10
corpus,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
readme,readme,all-mpnet-base-v2 (sentence_transformer),0.92857,4.25202,0.42507,0.64151,0.51132,0.42507
flat_code2doc_generated_readme,flat_code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.88679,3.44744,0.34461,0.54717,0.43423,0.34461
flat_repomap_code2doc_generated_readme,flat_repomap_code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.87466,3.25606,0.32547,0.53235,0.40889,0.32547
code2doc_generated_readme,code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.86253,3.15768,0.31563,0.51887,0.4,0.31563
repomap_code2doc_generated_readme,repomap_code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.84636,3.06739,0.3066,0.52022,0.38679,0.3066
generated_tasks,generated_tasks,all-mpnet-base-v2 (sentence_transformer),0.82749,2.78571,0.27844,0.47305,0.34987,0.27844
code2doc_files_summary,code2doc_files_summary,bm25,0.80458,2.469,0.24677,0.46081,0.32486,0.2473
repomap_code2doc_files_summary,repomap_code2doc_files_summary,bm25,0.80054,2.43531,0.2434,0.43708,0.31475,0.24384
repository_signature,repository_signature,all-mpnet-base-v2 (sentence_transformer),0.78571,2.28167,0.22803,0.41644,0.29623,0.22803
dependency_signature,dependency_signature,all-mpnet-base-v2 (sentence_transformer),0.7372,1.91375,0.19124,0.40566,0.26038,0.19124


In [97]:
all_metrics_df.groupby("retriever").apply(lambda df: df.sort_values("Accuracy@10", ascending=False).iloc[0])[metric_df_cols].sort_values("Accuracy@10", ascending=False)

  all_metrics_df.groupby("retriever").apply(lambda df: df.sort_values("Accuracy@10", ascending=False).iloc[0])[metric_df_cols].sort_values("Accuracy@10", ascending=False)


Unnamed: 0_level_0,corpus,retriever,Accuracy@10,Hits@10,R_cap@10,P@1,P@5,P@10
retriever,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
all-mpnet-base-v2 (sentence_transformer),readme,all-mpnet-base-v2 (sentence_transformer),0.92857,4.25202,0.42507,0.64151,0.51132,0.42507
all-MiniLM-L12-v2 (sentence_transformer),readme,all-MiniLM-L12-v2 (sentence_transformer),0.91914,4.03639,0.4035,0.66173,0.51482,0.4035
bm25,readme,bm25,0.91375,4.22911,0.42278,0.59973,0.50081,0.42278
static-retrieval-mrl-en-v1 (sentence_transformer),flat_code2doc_generated_readme,static-retrieval-mrl-en-v1 (sentence_transformer),0.79245,2.39353,0.23922,0.42992,0.30404,0.23922
Python code word2vec,generated_tasks,Python code word2vec,0.67385,1.7372,0.17358,0.30458,0.21348,0.17426
embeddinggemma-300m (sentence_transformer),generated_tasks,embeddinggemma-300m (sentence_transformer),0.18194,0.26146,0.02601,0.0566,0.03235,0.02601


In [98]:
all_metrics_df[all_metrics_df["retriever"] == "bm25"][metric_df_cols]

Unnamed: 0,corpus,retriever,Accuracy@10,Hits@10,R_cap@10,P@1,P@5,P@10
0,readme,bm25,0.91375,4.22911,0.42278,0.59973,0.50081,0.42278
10,flat_code2doc_generated_readme,bm25,0.84906,3.10782,0.31065,0.52432,0.39081,0.31162
12,flat_repomap_code2doc_generated_readme,bm25,0.84501,2.94879,0.29474,0.52973,0.37216,0.29568
6,code2doc_generated_readme,bm25,0.83288,2.77493,0.27736,0.50271,0.35881,0.27859
8,repomap_code2doc_generated_readme,bm25,0.83288,2.74124,0.27399,0.50338,0.34804,0.27564
7,code2doc_files_summary,bm25,0.80458,2.469,0.24677,0.46081,0.32486,0.2473
9,repomap_code2doc_files_summary,bm25,0.80054,2.43531,0.2434,0.43708,0.31475,0.24384
5,generated_tasks,bm25,0.77493,2.37197,0.23706,0.45306,0.30966,0.24
4,repository_signature,bm25,0.75876,2.20081,0.21995,0.42954,0.28645,0.22127
2,selected_code,bm25,0.68598,1.82345,0.18221,0.37669,0.24472,0.18306


In [99]:
len(task_queries)

742

In [100]:
all_metrics_df.shape

(72, 50)

In [101]:
# task count = 5

In [102]:
all_metrics_df[["corpus", "retriever", "Accuracy@10"]].sort_values("Accuracy@10", ascending=False)

Unnamed: 0,corpus,retriever,Accuracy@10
29,readme,all-mpnet-base-v2 (sentence_transformer),0.92857
30,readme,all-MiniLM-L12-v2 (sentence_transformer),0.91914
0,readme,bm25,0.91375
70,flat_code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.88679
69,flat_code2doc_generated_readme,all-mpnet-base-v2 (sentence_transformer),0.88005
78,flat_repomap_code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.87466
77,flat_repomap_code2doc_generated_readme,all-mpnet-base-v2 (sentence_transformer),0.87062
54,code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.86253
53,code2doc_generated_readme,all-mpnet-base-v2 (sentence_transformer),0.8558
10,flat_code2doc_generated_readme,bm25,0.84906


In [65]:
# task count = 10

In [66]:
all_metrics_df[["corpus", "retriever", "Accuracy@10"]].sort_values("Accuracy@10", ascending=False)

Unnamed: 0,corpus,retriever,Accuracy@10
29,readme,all-mpnet-base-v2 (sentence_transformer),0.92857
30,readme,all-MiniLM-L12-v2 (sentence_transformer),0.91914
0,readme,bm25,0.91375
73,flat_code2doc_files_summary,all-mpnet-base-v2 (sentence_transformer),0.89353
70,flat_code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.88679
69,flat_code2doc_generated_readme,all-mpnet-base-v2 (sentence_transformer),0.88005
74,flat_code2doc_files_summary,all-MiniLM-L12-v2 (sentence_transformer),0.87871
78,flat_repomap_code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.87466
77,flat_repomap_code2doc_generated_readme,all-mpnet-base-v2 (sentence_transformer),0.87062
54,code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.86253


In [67]:
all_metrics_df.groupby("retriever")["Accuracy@10"].agg("mean").sort_values()

retriever
embeddinggemma-300m (sentence_transformer)           0.065942
Python code word2vec                                 0.487774
static-retrieval-mrl-en-v1 (sentence_transformer)    0.637659
bm25                                                 0.788120
all-MiniLM-L12-v2 (sentence_transformer)             0.806796
all-mpnet-base-v2 (sentence_transformer)             0.817964
Name: Accuracy@10, dtype: float64

In [68]:
all_metrics_df.groupby("retriever")["Accuracy@10"].agg("mean").sort_values()

retriever
embeddinggemma-300m (sentence_transformer)           0.065942
Python code word2vec                                 0.487774
static-retrieval-mrl-en-v1 (sentence_transformer)    0.637659
bm25                                                 0.788120
all-MiniLM-L12-v2 (sentence_transformer)             0.806796
all-mpnet-base-v2 (sentence_transformer)             0.817964
Name: Accuracy@10, dtype: float64

In [69]:
all_metrics_df.groupby("retriever")["Accuracy@10"].agg("mean").sort_values()

retriever
embeddinggemma-300m (sentence_transformer)           0.065942
Python code word2vec                                 0.487774
static-retrieval-mrl-en-v1 (sentence_transformer)    0.637659
bm25                                                 0.788120
all-MiniLM-L12-v2 (sentence_transformer)             0.806796
all-mpnet-base-v2 (sentence_transformer)             0.817964
Name: Accuracy@10, dtype: float64

In [70]:
all_metrics_df.groupby("corpus")["Accuracy@10"].agg("mean").sort_values()

corpus
repomap                                   0.432167
dependency_signature                      0.452158
selected_code                             0.452828
repomap_code2doc_files_summary            0.568285
repository_signature                      0.574797
code2doc_files_summary                    0.578168
repomap_code2doc_generated_readme         0.639713
code2doc_generated_readme                 0.650495
flat_repomap_code2doc_files_summary       0.663970
flat_repomap_code2doc_generated_readme    0.666892
generated_tasks                           0.671607
flat_code2doc_generated_readme            0.680143
readme                                    0.685983
flat_code2doc_files_summary               0.692723
Name: Accuracy@10, dtype: float64

In [71]:
sampled_repos_df["tasks"].explode().value_counts().loc[list(task_queries.values())]

tasks
image classification                                         271
representation learning                                      252
frame                                                        246
question answering                                           223
transfer learning                                            219
language modelling                                           217
machine translation                                          177
image generation                                             175
data augmentation                                            173
classification                                               164
domain adaptation                                            164
time series                                                  157
super resolution                                             157
word embeddings                                              154
pose estimation                                              147
denoising          

# Final metrics

In [80]:
all_metrics_df[["corpus", "retriever", "Accuracy@10", "P@10"]].sort_values("Accuracy@10", ascending=False)

Unnamed: 0,corpus,retriever,Accuracy@10,P@10
29,readme,all-mpnet-base-v2 (sentence_transformer),0.92857,0.42507
30,readme,all-MiniLM-L12-v2 (sentence_transformer),0.91914,0.4035
0,readme,bm25,0.91375,0.42278
73,flat_code2doc_files_summary,all-mpnet-base-v2 (sentence_transformer),0.89353,0.32466
70,flat_code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.88679,0.34461
69,flat_code2doc_generated_readme,all-mpnet-base-v2 (sentence_transformer),0.88005,0.34879
74,flat_code2doc_files_summary,all-MiniLM-L12-v2 (sentence_transformer),0.87871,0.32412
78,flat_repomap_code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.87466,0.32547
77,flat_repomap_code2doc_generated_readme,all-mpnet-base-v2 (sentence_transformer),0.87062,0.33073
54,code2doc_generated_readme,all-MiniLM-L12-v2 (sentence_transformer),0.86253,0.31563


## Does combining rationale with generated readme help?

It seems that the best sentence transformer retrievers can only get worse when using any other information!

In [66]:
sentence_transformer_results.keys()

dict_keys([('readme', 'all-mpnet-base-v2'), ('readme', 'all-MiniLM-L12-v2'), ('code2doc_generated_readme', 'all-mpnet-base-v2'), ('code2doc_generated_readme', 'all-MiniLM-L12-v2'), ('selected_code', 'all-mpnet-base-v2'), ('selected_code', 'all-MiniLM-L12-v2'), ('code2doc_reasoning', 'all-mpnet-base-v2'), ('code2doc_reasoning', 'all-MiniLM-L12-v2'), ('code2doc_generation_context', 'all-mpnet-base-v2'), ('code2doc_generation_context', 'all-MiniLM-L12-v2'), ('dependency_signature', 'all-mpnet-base-v2'), ('dependency_signature', 'all-MiniLM-L12-v2'), ('repository_signature', 'all-mpnet-base-v2'), ('repository_signature', 'all-MiniLM-L12-v2'), ('generated_tasks', 'all-mpnet-base-v2'), ('generated_tasks', 'all-MiniLM-L12-v2')])

In [67]:
st_generated_readme_results= sentence_transformer_results[('generated_readme', 'all-mpnet-base-v2')].retrieval_results
st_rationale_results = sentence_transformer_results[('generated_rationale', 'all-mpnet-base-v2')].retrieval_results
bm25_generated_readme_results = bm25_results["generated_readme"].retrieval_results
st_context_results = sentence_transformer_results[('generation_context', 'all-mpnet-base-v2')].retrieval_results

KeyError: ('generated_readme', 'all-mpnet-base-v2')

In [None]:
len(list(bm25_generated_readme_results.keys()))

In [None]:
len(list(st_generated_readme_results.keys()))

In [None]:
def merge_qrels(qrels1, qrels2):
    merged_qrels = {}
    for k in qrels1.keys():
        tmp_rel = dict()
        for rel_k in set(qrels1[k].keys()).union(qrels2[k]):
            tmp_rel[rel_k] = qrels1[k].get(rel_k, 0) +  qrels2[k].get(rel_k, 0)
        merged_qrels[k] = tmp_rel
    return merged_qrels

In [None]:
st_generation_results = merge_qrels(bm25_generated_readme_results, st_generated_readme_results)

In [None]:
EvaluateRetrieval().evaluate_custom(task_qrels, st_generation_results, metric="acc", k_values=[1,5,10])

In [None]:
EvaluateRetrieval().evaluate_custom(task_qrels, st_generated_readme_results, metric="acc", k_values=[1,5,10])

In [None]:
EvaluateRetrieval().evaluate_custom(task_qrels, st_rationale_results, metric="acc", k_values=[1,5,10])

In [None]:
all_metrics_df[all_metrics_df["retriever"] == "bm25"][["corpus", "retriever", "Accuracy@10"]].sort_values("Accuracy@10")

In [None]:
Splitting does not make much sense as the most of generated data is under the sentence-transformer context length (384 tokens)

In [None]:
def split_corpus_by_lengths(corpus, chunk_length):
    splitted_corpora = [dict() for _ in range(n_splits)]
    for c_id in corpus.keys():
        text = corpus[c_id]["text"]
        chunk_length =  len(text) // n_splits
        for i in range(0, n_splits):
            splitted_corpora[i] = text[i*chunk_length:(i+1)*chunk_length]
        

In [None]:
class MultiTextEvaluator(BaseModel):
    """
    Evaluate a dataframe that has multiple texts for each query (multiple generation experiments)
    iteration_col says which experiment it was
    """
    iteration_col: str
    text_cols: List[str]
    k_values: List[int] = [1,5,10,25]

    def get_ir_datas(self, df):
        for iter in df[self.iteration_col].unique():
            ir_data = load_ir_data(df[df[self.iteration_col] == iter], self.text_cols)
            yield (iter, ir_data)

    def evaluate(self, df, retriever):
        ir_datas = dict(self.get_ir_datas(df))
        dfs = []
        for iter, ir_data in ir_datas.items():
            per_query_evaluator = PerQueryIREvaluator(k_values=self.k_values)
            df = per_query_evaluator.get_scores(ir_data, retriever)
            df[self.iteration_col] = iter
            dfs.append(df)
        metrics_df = pd.concat(dfs)
        metrics_df["query"] = metrics_df.index
        return metrics_df