In [1]:
import os
import sys

# Put the parent of the current working directory on sys.path (once) so that
# the project-local packages imported in later cells can be resolved.
project_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if sys.path.count(project_path) == 0:
    sys.path.append(project_path)

from dotenv import load_dotenv

# Load the project's .env file; values from the file override variables that
# are already set in the environment (override=True).
env_file = os.path.join(project_path, '.env')
load_dotenv(env_file, override=True)

True

In [2]:
import pyalex
from llmrankers.setwise import OpenAiSetwiseLlmRanker
from llmrankers.rankers import SearchResult
from itertools import chain


# monkey patch pyalex.Work to implement __hash__ (using id)
def custom_hash(self):
    """Return an identity-based hash for ``self``.

    The hash is derived from ``id(self)``, i.e. from object identity and not
    from the object's contents.
    """
    object_identity = id(self)
    return object_identity


# Install the identity-based hash on pyalex.Work so that Work instances can be
# collected in sets (get_referenced_publications relies on this).
# NOTE(review): hashing by id() means two separate instances describing the
# same OpenAlex work are not guaranteed to collapse into one set entry.
pyalex.Work.__hash__ = custom_hash

In [3]:
from cli.main import Services
from core.dataclasses.data_classes import Work, ScoredWork

# Build the service container once and expose the handles used by the
# evaluation helpers below as notebook-level names.
services = Services()

# Service-layer handles.
publication_service, user_service = services.publication_service, services.user_service

# Repository-layer handles, reached through their owning services.
publication_repository, user_repository = (
    publication_service.publication_repository,
    user_service.user_repository,
)

In [4]:
# db setup
from db.utils import recreate_tables
def clean_tables():
    """Reset the database state between evaluation runs.

    Truncates the publication table first and the user table second. According
    to the original note the truncation cascades to dependent tables.
    """
    for repository in (publication_repository, user_repository):
        repository.truncate()

In [5]:
from core.sqlalchemy_models.openalex.topic import Topic
def create_profile(user_name: str, publication: pyalex.Work) ->list[str]:
    """Set the user's area of interest from a publication's abstract.

    Args:
        user_name: Name of the user whose area of interest is set.
        publication: Work whose ``"abstract"`` is used as the interest description.

    Returns:
        The names of the topics associated with the user by
        ``user_service.set_area_of_interest`` (10 topics are requested,
        description alignment is disabled).
    """
    associations, _unused = user_service.set_area_of_interest(
        user_name,
        publication["abstract"],
        align_description=False,
        n_topics=10,
    )
    return [association.topic.name for association in associations]

In [6]:
import math


def get_random_publications_openalex(n: int) -> list[pyalex.Work]:
    """Fetch a seeded sample of works from OpenAlex.

    The query asks for a sample of ``n`` works (fixed seed 42), restricted to
    works that have an abstract and references, and is read page by page
    (page-based pagination, 100 works per page).

    Args:
        n: Sample size passed to the OpenAlex query.

    Returns:
        All works from all returned pages, in the order they were received.
    """
    sample_query = pyalex.Works().sample(n, seed=42)
    sample_query = sample_query.filter(has_abstract=True).filter(has_references=True)
    return [
        work
        for page in sample_query.paginate(method="page", per_page=100)
        for work in page
    ]


def filter_ineligible_publications(
    publications: list[pyalex.Work], max: int, min_references: int = 20
) -> list[pyalex.Work]:
    """Keep the first ``max`` publications that cite enough other works.

    Args:
        publications: Candidate works; each must expose ``"referenced_works"``.
        max: Upper bound on the number of publications returned. (The name
            shadows the builtin but is kept because callers pass it by keyword.)
        min_references: Minimum number of referenced works a publication needs
            in order to be eligible. Defaults to 20.

    Returns:
        The eligible publications in input order, at most ``max`` of them.
        An empty list is returned when ``max`` is zero or negative.
    """
    filtered_publications = []
    # Guard the cap up front: checking it only after appending would let one
    # publication through even when no results were requested.
    if max <= 0:
        return filtered_publications

    for publication in publications:
        # Treat a missing (None) reference list like an empty one.
        reference_count = len(publication["referenced_works"] or [])
        if reference_count >= min_references:
            filtered_publications.append(publication)
            if len(filtered_publications) >= max:
                break

    return filtered_publications


def get_referenced_publications(publication: pyalex.Work, require_abstract=False) -> set[pyalex.Work]:
    """Fetch the works cited by ``publication`` from OpenAlex.

    Args:
        publication: Work whose ``"referenced_works"`` ids are looked up.
        require_abstract: When true, only works that have an abstract are requested.

    Returns:
        A set of the fetched works, excluding the placeholder entity that
        stands in for deleted works. Storing works in a set requires
        ``pyalex.Work`` to be hashable.
    """
    deleted_work_placeholder = "https://openalex.org/W4285719527"  # dummy entity representing deleted works
    batch_size = 100

    # Shortened "W..." ids are used because the query gets too long otherwise.
    short_ids = ["W" + full_id.split("W")[-1] for full_id in set(publication["referenced_works"])]

    found_works = set()
    for start in range(0, len(short_ids), batch_size):
        batch_query = pyalex.Works().filter(openalex="|".join(short_ids[start: start + batch_size]))
        if require_abstract:
            batch_query = batch_query.filter(has_abstract=True)

        for page in batch_query.paginate(per_page=100, n_max=100):
            for work in page:
                if work["id"] == deleted_work_placeholder:
                    continue
                found_works.add(work)

    return found_works

def insert_publications(publications: list[pyalex.Work]):
    """Insert the given works into the publication table and rebuild the BM25 index.

    Works that are already stored (looked up by OpenAlex id) are skipped.
    Abstract embeddings are requested only for the works that actually need to
    be inserted, so no embedding cost is incurred for works that already exist.

    Args:
        publications: Works to make available in the database. Any iterable of
            works is accepted (the notebook passes a set).

    Raises:
        ValueError: If the number of embeddings returned does not match the
            number of works to insert.
    """
    from datetime import datetime
    from core.llm_interfaces import OpenAIInterface

    # first, convert to our own dataclass
    publications_converted = [Work(publication) for publication in publications]

    # Decide what is missing BEFORE embedding anything. De-duplicate by id so a
    # work that occurs twice in the input is still inserted only once.
    works_to_insert = []
    seen_ids = set()
    for work in publications_converted:
        if work.id in seen_ids:
            continue
        seen_ids.add(work.id)
        if not publication_repository.get_by_openalex_id(work.id):
            works_to_insert.append(work)

    if works_to_insert:
        openai_interface = OpenAIInterface(print_usage_info=False)
        abstract_embeddings = openai_interface.create_embedding_batch(
            [work.abstract for work in works_to_insert]
        )
        # strict=True: fail loudly rather than silently dropping works if the
        # embedding batch does not line up with the input.
        for work, embedding in zip(works_to_insert, abstract_embeddings, strict=True):
            publication_repository.create(
                openalex_id=work.id,
                title=work.title,
                authors=work.authors,
                abstract=work.abstract,
                published=work.publication_date,
                accessed=datetime.now(),
                embedding=embedding,
            )

    print("Rebuilding BM25 index...")
    publication_repository.rebuild_bm25()
    print("Rebuilt BM25 index.")
    
def check_all_publications_in_db(publications: list[pyalex.Work]) -> list[str]:
    """Check whether all given works are stored in the database.

    Args:
        publications: Works to look up; each ``"id"`` is an OpenAlex URL whose
            numeric part (after the ``W``) is used for the lookup.

    Returns:
        The OpenAlex ids (``publication["id"]``) of the works that were NOT
        found. An empty list means every work is present.
    """
    missed_publications = []
    for publication in publications:
        numeric_id = int(publication["id"].split("W")[-1])
        # Test the lookup result, not the input work (which is always truthy),
        # and report the id of the missing work rather than the empty result.
        if not publication_repository.get_by_openalex_id(numeric_id):
            missed_publications.append(publication["id"])
    return missed_publications

def convert_to_pyalex_works(works: list[Work]) -> list[pyalex.Work]:
    """Fetch the OpenAlex records for ``works``, preserving the input order.

    The OpenAlex filter query returns matches in its own order, not in the
    order the ids were requested. Callers treat list position as rank, so the
    fetched records are re-ordered to follow ``works``.

    Args:
        works: Works whose ``id`` is the numeric part of the OpenAlex id.

    Returns:
        One pyalex work per distinct requested id that OpenAlex returned, in
        the order of first occurrence in ``works``. Any returned record that
        cannot be matched to a requested id is appended at the end, so nothing
        the API returned is dropped.
    """
    # dict.fromkeys de-duplicates while keeping first-occurrence order.
    openalex_ids = list(dict.fromkeys(f"W{work.id}" for work in works))
    requested_ids = set(openalex_ids)

    fetched_by_id = {}
    unmatched_works = []
    for i in range(0, len(openalex_ids), 100):
        query = pyalex.Works().filter(openalex="|".join(openalex_ids[i : i + 100]))
        for pyalex_work in chain(*query.paginate(per_page=100, n_max=100)):
            # Full ids look like "https://openalex.org/W123"; reduce to "W123".
            short_id = "W" + pyalex_work["id"].split("W")[-1]
            if short_id in requested_ids and short_id not in fetched_by_id:
                fetched_by_id[short_id] = pyalex_work
            else:
                unmatched_works.append(pyalex_work)

    ordered_works = [
        fetched_by_id[openalex_id]
        for openalex_id in openalex_ids
        if openalex_id in fetched_by_id
    ]
    return ordered_works + unmatched_works

def get_retrieval_results(publication: pyalex.Work, n: int = 100) -> list[pyalex.Work]:
    """Run a hybrid search with the publication's abstract as the query.

    Args:
        publication: Query work; its ``"abstract"`` is used as the query text.
        n: Maximum number of results to return.

    Returns:
        Up to ``n`` retrieved works, converted to pyalex works via
        ``convert_to_pyalex_works``.
    """
    # Imported locally so the function does not depend on a later notebook
    # cell having imported datetime before it is called.
    from datetime import datetime
    from core.services.publication_service import SearchType

    # Passed to the search as its date argument; 1900-01-01 is presumably early
    # enough not to exclude anything - TODO confirm the parameter's semantics.
    start_date = datetime(1900, 1, 1)
    query = publication["abstract"]
    # Request some buffer to account for possible duplicates and for results
    # dropped by the score filter below.
    results_hybrid = publication_service.get_relevant_works_for_query(query, n + 10, start_date, SearchType.HYBRID)
    # Keep only works with score < 0.95; anything scoring 0.95 or higher is
    # likely the query document itself and is removed.
    results_hybrid = [scored_work.work for scored_work in results_hybrid if scored_work.score < 0.95]

    return convert_to_pyalex_works(results_hybrid[:n])

def rerank(query_work: pyalex.Work, retrieval_results: list[pyalex.Work]) -> tuple[list[pyalex.Work], float]:
    """Rerank retrieved works against the query work with a setwise LLM ranker.

    Args:
        query_work: Work whose ``"abstract"`` is the reranking query.
        retrieval_results: Candidate works; each ``"abstract"`` is the document text.

    Returns:
        A tuple ``(reranked_works, cost)``: the candidate works in the order
        returned by the ranker, and the cost of the call computed from the
        reported token usage and the hard-coded per-token prices.
    """
    reranker = OpenAiSetwiseLlmRanker(model_name_or_path="gpt-4o-mini-2024-07-18", api_key=os.getenv("OPENAI_API_KEY"), method="heapsort", num_child=2, k=10)
    # Hard-coded prices per token (0.15 / 0.60 per million tokens); presumably
    # USD list prices for the model above - verify before relying on the cost.
    price_per_input_token = 0.15 / 1e6
    price_per_output_token = 0.6 / 1e6

    docs = [SearchResult(docid=result["id"], text=result["abstract"], score=None) for result in retrieval_results]
    reranked_docs, usage = reranker.rerank(query_work["abstract"], docs)

    # Index the candidates by id once instead of scanning the whole candidate
    # list for every reranked document. Lists keep duplicate ids intact.
    works_by_id = {}
    for work in retrieval_results:
        works_by_id.setdefault(work["id"], []).append(work)

    reranked_works = []
    for doc in reranked_docs:
        reranked_works.extend(works_by_id.get(doc.docid, []))

    cost = usage["total_prompt_tokens"] * price_per_input_token + usage["total_completion_tokens"] * price_per_output_token
    print(f"Reranking cost: {cost}")

    return reranked_works, cost

from dataclasses import dataclass
@dataclass
class ResultStats:
    """One row of evaluation output: a single result for a single query work."""

    # Index of the evaluation run (one run per query work).
    run: int
    # Which result list the row belongs to (the notebook uses "retrieval" and "reranking").
    type: str
    # OpenAlex id of the query work.
    query_work: str
    # OpenAlex id of the result work.
    result_work: str
    # 1-indexed position of the result in its list.
    result_rank: int
    # Number of references shared by the query work and the result work.
    common_references: int
    # Upper bound for common_references: the smaller of the two reference counts.
    max_common_references: int
    # Whether the result work is itself cited by the query work.
    is_reference_of_citing_work: bool
    # Number of the query work's references available in the corpus.
    num_query_work_references_in_corpus: int
        
        
        

def calculate_stats(query_work: pyalex.Work, query_work_references_in_corpus: list[pyalex.Work], retrieval_results: list[pyalex.Work], run: int, type: str) -> list[ResultStats]:
    """Build one ``ResultStats`` row per result, in result order.

    Args:
        query_work: The work used as the query.
        query_work_references_in_corpus: The query work's references that are
            available in the corpus; only their count is recorded.
        retrieval_results: Ranked result works; list position determines the rank.
        run: Run index stored on every row.
        type: Label stored on every row (e.g. "retrieval" or "reranking").

    Returns:
        A list of ``ResultStats`` with 1-indexed ranks.
    """
    query_reference_ids = set(query_work["referenced_works"])
    # relevant for calculating (retrieved_references / max_retrieved_references)
    corpus_reference_count = len(query_work_references_in_corpus)

    def build_row(rank: int, candidate) -> ResultStats:
        # Reference overlap between the query work and this result.
        candidate_reference_ids = set(candidate["referenced_works"])
        return ResultStats(
            run=run,
            type=type,
            query_work=query_work["id"],
            result_work=candidate["id"],
            result_rank=rank,
            common_references=len(query_reference_ids & candidate_reference_ids),
            max_common_references=min(len(query_reference_ids), len(candidate_reference_ids)),
            is_reference_of_citing_work=candidate["id"] in query_reference_ids,
            num_query_work_references_in_corpus=corpus_reference_count,
        )

    return [build_row(rank, candidate) for rank, candidate in enumerate(retrieval_results, start=1)]

In [None]:
from datetime import datetime
import pandas as pd 

In [7]:
from datetime import datetime
import pandas as pd

# Evaluation driver: for each eligible query work, rebuild the corpus, retrieve
# and rerank candidates, and record reference-overlap statistics.

# Captured once so every checkpoint of this execution goes to the same file.
timestamp = datetime.now().isoformat()

# Maximum number of query works to evaluate.
n = 100
# Passed as `limit` to publication_service.initialize_for_user.
limit_initial_retrieval = 2000  
# Passed as `start_date` to publication_service.initialize_for_user.
start_date = datetime(2020, 1, 1)
# Request 5x the target, since the eligibility filter below discards works
# with too few references and keeps at most n.
random_publications = get_random_publications_openalex(n * 5)
filtered_publications = filter_ineligible_publications(random_publications, max=n)

# Running total of the cost reported by rerank().
reranking_cost = 0.0
# Accumulates ResultStats rows across all runs.
stats = []
for i, query_work in enumerate(filtered_publications):
    print(f"Run: {i} --- Title: {query_work['title']}, id: {query_work['id']}")
    # clean db and create user + profile
    clean_tables()
    user_service.create_user("eval_user", "Eval User", "-")
    topics = create_profile("eval_user", query_work)
    print(f"Top 3 topics: {topics[:3]}")
    # Populate the database with publications for the user's profile.
    publications_retrieved_for_topics = publication_service.initialize_for_user("eval_user", start_date=start_date, limit=limit_initial_retrieval)
    print(f"Retrieved and inserted {len(publications_retrieved_for_topics)} publications.")    
    # Add the query work's own references (only those with an abstract) to the database.
    referenced_works = get_referenced_publications(query_work, require_abstract=True)
    insert_publications(referenced_works)
    print(f"Inserted {len(referenced_works)} referenced works.")
    # sanity check: are all referenced works in the database?
    print(f"Found all referenced works in the database? {len(check_all_publications_in_db(referenced_works)) == 0}")
    retrieval_results = get_retrieval_results(query_work, 100)
    # NOTE: reusing double quotes and a backslash inside the f-string
    # expression requires Python 3.12+ (PEP 701).
    print(f"Top 5 results: {"".join(f'\n\t{result["title"]}' for result in retrieval_results[:5])}")
    print(f"Reranking...")
    reranked_results, cost = rerank(query_work, retrieval_results)
    reranking_cost += cost
    print(f"Reranked results: {"".join(f'\n\t{result["title"]}' for result in reranked_results[:5])}")
    print(f"Accumulated reranking cost: {reranking_cost}")
    # Same statistics for both orderings so they can be compared per run.
    retrieval_stats = calculate_stats(query_work=query_work,query_work_references_in_corpus=referenced_works, retrieval_results=retrieval_results, run=i, type="retrieval")
    reranking_stats = calculate_stats(query_work=query_work,query_work_references_in_corpus=referenced_works, retrieval_results=reranked_results, run=i, type="reranking")
    
    stats.extend(retrieval_stats)
    stats.extend(reranking_stats)
    
    # Checkpoint after every run: the file is rewritten with all rows so far,
    # so partial results are kept if a later run fails.
    df = pd.DataFrame(stats)
    # save to timestamped pickle file
    df.to_pickle(f"eval_common_references_{timestamp}.pkl")
    
    print("-" * 80)

Run: 0 --- Title: Original Articles Immunoreactivity of Stat5 Phosphorylated on Tyrosine as a Cell-Based Measure of Bcr/Abl Kinase Activity, id: https://openalex.org/W316998542
Model: text-embedding-3-large, Tokens: 120, Cost: $0.00, Accumulated cost: $0.00
Top 3 topics: ['Role of STAT3 in Cancer Inflammation and Immunity', 'Transforming Growth Factor Beta Signaling Pathway', 'Calcineurin-NFAT Signaling in Transcriptional Regulation']
Model: text-embedding-3-large, Tokens: 710638, Cost: $0.09, Accumulated cost: $0.09
Retrieved and inserted 2000 publications.
Rebuilding BM25 index...
Rebuilt BM25 index.
Inserted 34 referenced works.
Found all referenced works in the database? True
Model: text-embedding-3-large, Tokens: 120, Cost: $0.00, Accumulated cost: $0.09
Top 5 results: 
	The JAK/STAT signaling pathway: from bench to clinic
	The role of shared receptor motifs and common stat proteins in the generation of cytokine pleiotropy and redundancy by IL-2, IL-4, IL-7, IL-13, and IL-15
	The 