## Create client

In [1]:
import cohere
import os
from dotenv import load_dotenv

# Pull settings from the local .env file into the process environment.
load_dotenv()

# Build the Cohere v2 client from the key stored under COHERE_API_KEY.
client = cohere.ClientV2(os.environ.get("COHERE_API_KEY"))

## Get the query string 

In [31]:
import pandas as pd
from pprint import pprint

# Read the JSON-lines dataset and pick a single example row to work with.
df = pd.read_json("data/dataset/100/nontrivial.jsonl", lines=True)
example_row = df.iloc[29]

# The sentence, its publication date, and the first DOI it cites.
query = example_row.sent_no_cit
query_pubdate = example_row.pubdate
target_doi = example_row.citation_dois[0]

for label, value in (
    ("Query", query),
    ("Query pubdate", query_pubdate),
    ("Target DOI", target_doi),
):
    pprint(f"{label}: {value}")


('Query: The work of  suggests a fundamental distinction in elemental '
 'abundances between closed and open magnetic structures, matching the nominal '
 'photospheric and coronal abundances, respectively.')
'Query pubdate: 1992-01-01'
'Target DOI: 10.1086/167871'


## Get the documents to be reranked

In [32]:
from database.database import Database
from Embedders import get_embedder


# Create the database handle and run its connection test before querying.
db = Database()
db.test_connection()

# Build the embedder, then embed the single query sentence. The embedder is
# called with a one-element list, so the first item of its output is the
# vector for `query`.
embedder = get_embedder(
    normalize=True,
    device='mps',
    model_name='BAAI/bge-small-en',
)
query_embeddings = embedder([query])
query_vector = query_embeddings[0]
print(f"Vector shape: {query_vector.shape}")

Database         User             Host                             Port            
citeline_db      bbasseri         localhost                        5432            
Database version: ('PostgreSQL 17.3 (Homebrew) on x86_64-apple-darwin23.6.0, compiled by Apple clang version 16.0.0 (clang-1600.0.26.6), 64-bit',)
Vector shape: (384,)


In [33]:
# Fetch up to 10000 candidates for the query vector from the 'bge_norm'
# column of table 'lib'.
# NOTE(review): `pubdate` presumably restricts candidates relative to the
# query's publication date -- confirm against Database.query_vector_column.
search_kwargs = dict(
    table_name='lib',
    target_column='bge_norm',
    top_k=10000,
    pubdate=query_pubdate,
    query_vector=query_vector,
)
results = db.query_vector_column(**search_kwargs)
print(f"Got {len(results)} results")

Got 10000 results


In [35]:
# DOIs of the retrieved candidates, in retrieval order.
result_dois = [r.doi for r in results]

target_found = target_doi in result_dois
print(f"Target found: {target_found}")

# list.index raises ValueError for a missing value, so only look the position
# up when the target is actually present; report None otherwise.
target_index = result_dois.index(target_doi) if target_found else None
print(f"Index of target DOI: {target_index}")

Target found: True
Index of target DOI: 32


In [37]:
# Keep the leading 33 hits (the target DOI was found at index 32 above) and
# collect the text chunk of each one as the documents to rerank.
top_results = results[:33]
docs = []
for hit in top_results:
    docs.append(hit.chunk)
print(len(docs))

33


In [38]:
# Rerank every candidate in `docs` against the query. `top_n` is derived from
# len(docs) rather than hard-coded, so it cannot drift out of sync with the
# slice that built `docs` (the same convention the evaluation loop uses).
co_results = client.rerank(
    model="rerank-v3.5",
    query=query,
    documents=docs,
    top_n=len(docs),
)

In [42]:
# Show each reranked entry next to its new (0-based) position.
for new_rank, reranked in enumerate(co_results.results):
    print(new_rank, reranked)

0 document=None index=7 relevance_score=0.56661874
1 document=None index=3 relevance_score=0.40147716
2 document=None index=10 relevance_score=0.39173222
3 document=None index=6 relevance_score=0.3538082
4 document=None index=0 relevance_score=0.3502663
5 document=None index=12 relevance_score=0.32830378
6 document=None index=15 relevance_score=0.24564147
7 document=None index=27 relevance_score=0.2372316
8 document=None index=9 relevance_score=0.22942336
9 document=None index=13 relevance_score=0.22553639
10 document=None index=32 relevance_score=0.22364864
11 document=None index=25 relevance_score=0.21650848
12 document=None index=19 relevance_score=0.1980335
13 document=None index=4 relevance_score=0.18643583
14 document=None index=20 relevance_score=0.14317599
15 document=None index=2 relevance_score=0.13353972
16 document=None index=22 relevance_score=0.13272823
17 document=None index=30 relevance_score=0.1287991
18 document=None index=14 relevance_score=0.1119482
19 document=None

Here we see that the document at original index 32 (the row with the target DOI) has moved up from rank 32 to rank 10 after Cohere reranking (both ranks are 0-based).

In [45]:
# Evaluate Cohere reranking over every row of the query dataframe.
#
# `rerankings` is a list of lists of dictionaries:
#   - Each row in the query dataframe gets one list.
#   - Each list contains one dict per target DOI of that row.
#   - If the target DOI is in the query results, the dict contains:
#       - target_at_idx: the index of the target DOI in the original results
#       - co_reranks: the reranked results from Cohere
#       - reranked_index: the index of the target DOI in the reranked results
#       - rank_change: original index minus reranked index
#         (positive = the target moved up)
#   - If the target DOI is not in the query results (or, defensively, is
#     missing from the reranked list), reranked_index and rank_change are None;
#     for a target absent from the query results target_at_idx is None and
#     co_reranks is [] as well.
#
# NOTE(review): all retrieved chunks (up to top_k=10000) are sent to the
# reranker in one request; Cohere's documentation recommends far fewer
# documents per call -- confirm this is intended.
rerankings = []
for row_idx, row in enumerate(df.itertuples()):
    print(row_idx, end=", ")
    query_vector = embedder([row.sent_no_cit])[0]
    query_pubdate = row.pubdate
    target_dois = row.citation_dois
    query_results = db.query_vector_column(
        query_vector=query_vector,
        table_name='lib',
        target_column='bge_norm',
        pubdate=query_pubdate,
        top_k=10000,
    )

    result_dois = [r.doi for r in query_results]
    result_docs = [r.chunk for r in query_results]

    # The rerank request depends only on the query and its candidates, not on
    # the target DOI, so it is issued lazily and at most once per row.
    co_reranks = None
    rank_by_original_index = {}

    rerank_results = []
    for target_doi in target_dois:
        if target_doi not in result_dois:
            print(f"Target DOI {target_doi} not found in results")
            rerank_results.append({
                'target_at_idx': None,
                'co_reranks': [],
                'reranked_index': None,
                'rank_change': None,
            })
            continue

        if co_reranks is None:
            co_results = client.rerank(
                model="rerank-v3.5",
                query=row.sent_no_cit,
                documents=result_docs,
                top_n=len(query_results)
            )
            co_reranks = co_results.results
            # Map each candidate's original position to its reranked position.
            rank_by_original_index = {
                r.index: rank for rank, r in enumerate(co_reranks)
            }

        # First position at which the target DOI appears in the original results.
        original_index = result_dois.index(target_doi)
        # .get() leaves None when the reranker did not return this candidate,
        # instead of silently computing a change from an unrelated index.
        reranked_index = rank_by_original_index.get(original_index)

        rerank_results.append({
            'target_at_idx': original_index,
            'co_reranks': co_reranks,
            'reranked_index': reranked_index,
            'rank_change': (
                None if reranked_index is None
                else original_index - reranked_index
            ),
        })
    rerankings.append(rerank_results)

0, Target DOI 10.1016/j.epsl.2013.07.013 not found in results
1, 2, 3, 4, 5, 6, 7, 8, Target DOI 10.1088/1475-7516/2012/07/038 not found in results
9, 10, Target DOI 10.1016/j.gca.2021.07.031 not found in results
11, Target DOI 10.1086/151310 not found in results
12, 13, 14, Target DOI 10.1016/j.jafrearsci.2008.01.004 not found in results
15, 16, 17, 18, 19, 20, 21, 22, 23, Target DOI 10.1093/mnras/stz2467 not found in results
24, 25, 26, 27, Target DOI 10.1103/PhysRevFluids.4.013803 not found in results
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, Target DOI 10.1086/339896 not found in results
41, 42, Target DOI 10.1086/319078 not found in results
43, 44, 45, 46, Target DOI 10.1086/308508 not found in results
47, 48, 49, 50, 51, 52, 53, 54, Target DOI 10.1086/378218 not found in results
55, Target DOI 10.1111/j.1365-2966.2010.16517.x not found in results
56, 57, 58, Target DOI 10.1086/510850 not found in results
59, 60, 61, 62, 63, 64, 65, Target DOI 10.1086/381085 not found in

In [46]:
# Collect the rank change of every target DOI that was found in its query's
# results (entries for missing targets carry rank_change=None and are skipped).
changes = [
    result['rank_change']
    for reranking in rerankings
    for result in reranking
    if result['rank_change'] is not None
]

# Guard the division: with no found targets the mean is undefined.
if changes:
    print(f"Average rank change: {sum(changes) / len(changes)}")
else:
    print("Average rank change: n/a (no target DOI was found in any result set)")

Average rank change: -451.2967032967033


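Because `rank_change` is the original index minus the reranked index, a negative value means the target moved down the list. This demonstrates that, on average, the Cohere reranking failed to rank the target chunk higher than before: it moved the target about 451 positions lower.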

In [51]:
# Inspect the first target of the second query: its rank change, its position
# in the original results, and its position after reranking.
inspected = rerankings[1][0]
for field in ('rank_change', 'target_at_idx', 'reranked_index'):
    pprint(inspected[field])

-3654
596
4250


In [53]:
# Look at the reranked entry at position 4250 (the reranked_index printed
# above) for the same query/target pair.
reranked_items = rerankings[1][0]['co_reranks']
print(reranked_items[4250])

document=None index=596 relevance_score=0.032298584
