In [1]:
from database.database import Database


# Create the project's database handle and run its connection check up front, so a
# misconfigured connection surfaces here rather than in the middle of a later query cell.
db = Database()
db.test_connection()

Database         User             Host                             Port            
citelinedb       bbasseri         localhost                        5432            
Database version: ('PostgreSQL 17.5 (Homebrew) on aarch64-apple-darwin24.4.0, compiled by Apple clang version 17.0.0 (clang-1700.0.13.3), 64-bit',)


In [2]:
import pandas as pd
import torch
from Embedders import get_embedder

# Choose the compute backend in order of preference: CUDA GPU, Apple-silicon MPS, then CPU.
# torch.backends.mps.is_available() is the documented availability check; the
# torch.mps.is_available() shortcut is missing from some older PyTorch releases.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# One JSON record per line; the rows are accessed later via sent_no_cit, pubdate and citation_dois.
samples = pd.read_json("data/dataset/split/small_train.jsonl", lines=True)

# Embedder built by the project's factory, with normalization requested.
bge_embedder = get_embedder("BAAI/bge-small-en", device=device, normalize=True)

print(f"Loaded samples: {len(samples)}")
print("Embedder:", bge_embedder)

Loaded samples: 10
Embedder: BAAI/bge-small-en, device=mps, normalize=True


In [8]:
# For every sample: embed the sentence (without its citation), retrieve nearest neighbours
# from the `lib` table, and report what share of the sample's cited DOIs was retrieved.
for sample in samples.itertuples():
    pubdate = sample.pubdate
    embedded_batch = bge_embedder([sample.sent_no_cit])
    query_vector = embedded_batch[0]

    # Vector search against the stored embeddings column.
    results = db.query_vector_column(
        table_name="lib",
        target_column="bge_norm",
        metric="vector_cosine_ops",
        query_vector=query_vector,
        pubdate=pubdate,
        top_k=40,
        ef_search=40,
        use_index=True,
    )

    # Several rows may share a DOI, so the number of distinct DOIs can be below top_k.
    result_dois = {hit.doi for hit in results}
    target_dois = set(sample.citation_dois)

    all_citations_in_top_k = target_dois <= result_dois
    if target_dois:
        pct_in_top_k = len(target_dois & result_dois) / len(target_dois)
    else:
        # No ground-truth citations: report 0 rather than dividing by zero.
        pct_in_top_k = 0
    print(f"{len(target_dois)} citations, {len(result_dois)} results, "
          f"{pct_in_top_k:.2%} of citations in top-k")

1 citations, 27 results, 0.00% of citations in top-k
1 citations, 35 results, 0.00% of citations in top-k
4 citations, 19 results, 0.00% of citations in top-k
3 citations, 32 results, 0.00% of citations in top-k
1 citations, 31 results, 0.00% of citations in top-k
1 citations, 27 results, 0.00% of citations in top-k
2 citations, 22 results, 50.00% of citations in top-k
1 citations, 19 results, 0.00% of citations in top-k
1 citations, 26 results, 0.00% of citations in top-k
1 citations, 17 results, 100.00% of citations in top-k


In [None]:
# Drill into the row at position 9 and rerun the search with a smaller top_k
# and a larger ef_search than the sweep over all samples.
sample = samples.iloc[9]
print(f"Sample 9: {sample.sent_no_cit}")
print(f"Sample 9 DOIs: {sample.citation_dois}")

sample_query_vector = bge_embedder([sample.sent_no_cit])[0]
query_results = db.query_vector_column(
    table_name="lib",
    target_column="bge_norm",
    metric="vector_cosine_ops",
    query_vector=sample_query_vector,
    pubdate=sample.pubdate,
    top_k=20,
    ef_search=1000,
    use_index=True,
)
print(f"Sample 9 query results: {len(query_results)}")

# Collapse rows that share a DOI to count distinct papers.
unique_dois = {hit.doi for hit in query_results}
print(f"Sample 9 unique DOIs in results: {len(unique_dois)}")

In [None]:
# De-duplicate the retrieved DOIs while keeping first-seen (retrieval) order.
# dict.fromkeys preserves insertion order and does this in one linear pass, instead of
# rescanning the growing list for every result.
unique_dois_list = list(dict.fromkeys(result.doi for result in query_results))
print(unique_dois_list)

In [None]:
# Look up one paper per retrieved DOI, in the same order as unique_dois_list.
candidates = list(map(db.get_paper_by_doi, unique_dois_list))

In [None]:
from apis.openai_client import deepseek_citation_validator_using_openai

# Hand the sentence and all retrieved candidate papers to the project's citation validator.
validation_responses = deepseek_citation_validator_using_openai(
    candidates=candidates,
    query=sample.sent_no_cit,
)

# Dump each returned item as-is for inspection.
for response in validation_responses:
    print(f"Validation response: {response}")

In [None]:
# Keep only the first item returned by the validator; the next cell iterates over it.
# NOTE(review): presumably this is a collection of chat-completion objects — confirm
# against apis.openai_client.
vals = validation_responses[0]


In [None]:
# Print the text of the first choice of every completion in `vals`.
for completion in vals:
    first_message = completion.choices[0].message
    print(first_message.content)

In [None]:
# Eyeball the two highest-ranked retrieved DOIs against the sample's ground-truth citations.
top_two_retrieved_dois = unique_dois_list[:2]
print(top_two_retrieved_dois)
print(sample.citation_dois)

In [None]:
# from apis.openai_client import deepseek_citation_validator

# deepseek_results = deepseek_citation_validator(query=sample.sent_no_cit, candidates=candidates)

In [None]:
from dotenv import load_dotenv
import os
# Load environment variables via python-dotenv, then fail fast if the DeepSeek key is
# still absent so the client cell below does not raise a bare KeyError.
load_dotenv()
deepseek_key_present = os.environ.get("DEEPSEEK_API_KEY") is not None
assert deepseek_key_present, "Please set DEEPSEEK_API_KEY in your environment variables."


In [None]:
from openai import OpenAI

# OpenAI SDK client pointed at the DeepSeek endpoint; the key is read from the
# environment and never hard-coded in the notebook.
client = OpenAI(
    base_url="https://api.deepseek.com",
    api_key=os.environ["DEEPSEEK_API_KEY"],
)

# Prompt template that ds_formatted fills via str.format with `sentence` and `paper`.
# The encoding is pinned to UTF-8 so decoding does not depend on the platform's locale default.
with open("llm/prompts/deepseek_citation_identification.txt", "r", encoding="utf-8") as f:
    DEEPSEEK_CITATION_IDENTIFICATION_PROMPT = f.read()

# Structured-output request in the shape the OpenAI chat-completions API documents for
# `response_format`: the `json_schema` entry carries a `name` plus the actual JSON Schema
# under `schema`, and the schema's root type is "object".
# NOTE(review): ds_formatted currently sends {"type": "json_object"} and does not use this
# constant; confirm the target endpoint accepts `json_schema` before switching over.
OUTPUT_FORMAT = {
    "type": "json_schema",
    "json_schema": {
        "name": "citation_decision",
        "strict": True,
        "schema": {
            "type": "object",
            "properties": {"should_cite": {"type": "boolean"}},
            "required": ["should_cite"],
            "additionalProperties": False,
        },
    },
}

def ds_formatted(query, candidates):
    """Send one chat-completion request per candidate paper for a single sentence.

    The module-level prompt template is filled once per candidate
    (``sentence=query``, ``paper=<candidate>``) and sent as a lone system message
    with JSON-object output requested. Every prompt is rendered before the first
    request is issued.

    Args:
        query: Sentence text substituted into the template's ``sentence`` field.
        candidates: Iterable of papers; each is substituted into the ``paper`` field.

    Returns:
        list: The objects returned by ``client.chat.completions.create``, one per
        candidate, in candidate order.
    """
    # Phase 1: render every prompt up front, so a template error stops before any API call.
    rendered_prompts = []
    for paper in candidates:
        rendered_prompts.append(
            DEEPSEEK_CITATION_IDENTIFICATION_PROMPT.format(sentence=query, paper=paper)
        )

    # Phase 2: issue the requests sequentially, preserving candidate order.
    completions = []
    for system_prompt in rendered_prompts:
        request_kwargs = {
            "model": "deepseek-chat",
            "temperature": 0.0,
            "messages": [{"role": "system", "content": system_prompt}],
            "response_format": {"type": "json_object"},
            "stream": False,
        }
        completions.append(client.chat.completions.create(**request_kwargs))
    return completions

In [None]:
# Smoke-test ds_formatted on the first candidate only, to keep the trial to one API call.
first_candidate_only = candidates[:1]
trial_results = ds_formatted(query=sample.sent_no_cit, candidates=first_candidate_only)


In [None]:
# Show the text of the first choice from the single trial completion.
first_trial_message = trial_results[0].choices[0].message
print(first_trial_message.content)