In [1]:
import pandas as pd
from tqdm import tqdm
import json
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from pprint import pprint
from citeline.database.milvusdb import MilvusDB
from citeline.embedders import Embedder
from citeline.query_expander import get_expander

db = MilvusDB()
print(db)

tqdm.pandas()

# Setup: load embedder, expander, dataset, db collection
embedder = Embedder.create("Qwen/Qwen3-Embedding-0.6B", device="mps", normalize=True)
print(embedder)

expander = get_expander("add_prev_3", path_to_data="../data/preprocessed/reviews.jsonl")
print(expander)

sample = pd.read_json("../data/dataset/nontrivial_100.jsonl", lines=True)
sample = sample.sample(20, random_state=42).reset_index(drop=True)

# Apply query expansion and embed the queries
sample["sent_no_cit"] = expander(sample)
sample["vector"] = sample.progress_apply(lambda row: embedder([row["sent_no_cit"]])[0], axis=1)

db.list_collections()
db.client.load_collection("qwen06_contributions")

sample.head()

<citeline.database.milvusdb.MilvusDB object at 0x17bbbb2d0>
Qwen/Qwen3-Embedding-0.6B, device=mps, normalize=True, dim=1024
QueryExpander(name=add_prev_3, data_length=2980)


100%|██████████| 20/20 [00:04<00:00,  4.30it/s]

Collections:
 - astrobert_chunks: 460801 entities
 - astrobert_contributions: 89860 entities
 - bge_chunks: 460801 entities
 - bge_contributions: 89860 entities
 - nasa_chunks: 460801 entities
 - nasa_contributions: 89860 entities
 - qwen06_chunks: 460801 entities
 - qwen06_contributions: 89860 entities
 - qwen06_findings_v2: 4342 entities
 - qwen8b_contributions: 89860 entities
 - specter_chunks: 460801 entities
 - specter_contributions: 89860 entities





Unnamed: 0,source_doi,sent_original,sent_no_cit,sent_idx,citation_dois,pubdate,resolved_bibcodes,sent_cit_masked,vector
0,10.1146/annurev-astro-081710-102521,Their abundance is important because molecular...,"In this limit, the important reactions are dis...",318,[10.1046/j.1365-8711.2002.04940.x],20110901,[2002MNRAS.329...18F],Their abundance is important because molecular...,"[-0.036558926, -0.016381508, -0.0110567855, 0...."
1,10.1146/annurev-astro-081817-051826,It is important to point out that the fraction...,Gesicki et al. (2014) derived masses and ages ...,231,"[10.1051/0004-6361/201220678, 10.1051/0004-636...",20180901,"[2013A&A...549A.147B, 2017A&A...605A..89B]",It is important to point out that the fraction...,"[-0.016821573, -0.014033405, -0.008635518, 0.0..."
2,10.1007/s00159-010-0029-x,How could the seed massive black holes have gr...,This argument is particularly important at ear...,259,"[10.1086/422910, 10.1086/427065, 10.1086/50744...",20100701,"[2004ApJ...613...36H, 2005ApJ...620...59S, 200...",How could the seed massive black holes have gr...,"[-0.018864796, -0.06410703, -0.0063532703, 0.0..."
3,10.1146/annurev.aa.31.090193.003441,Nature has somehow solved this problem in doub...,These models can reproduce the observed spectr...,457,[10.1086/161053],19930101,[1983ApJ...269..423R],Nature has somehow solved this problem in doub...,"[0.044838704, -0.013867467, -0.007506692, 0.04..."
4,10.1007/s00159-012-0055-y,"However, a similar linewidth–size scaling law ...","Size, internal velocity dispersion and column ...",306,[10.1051/0004-6361:20020629],20121101,[2002A&A...390..307O],"However, a similar linewidth–size scaling law ...","[-0.009350214, -0.056160886, -0.00880815, -0.0..."


In [2]:
def get_hard_records(example: pd.Series, n: int = 2) -> list[str]:
    """
    Overfetches 3*n most similar records (bc if two reps from same doc are in top n, we won't have n distinct non-target dois)

    Returns:
      A list of doi's, ordered by their max similarity to the query
    """
    results = db.search(
        collection_name="qwen06_contributions",
        query_records=[example.to_dict()],
        query_vectors=[example.vector],
        limit=3 * n,
    )
    results = results[0]  # db.search operates on lists of queries; we only need the first result

    # Filter results to non-targets only
    target_dois = set(example.citation_dois)
    non_target_results = [r for r in results if r["doi"] not in target_dois]
    return non_target_results[:n]


def get_similarity_to_targets(example: pd.Series) -> list[float]:
    """
    For each target doi in the example, computes the max similarity between the example and any record with that doi.

    Returns a list of scores in order of example.citation_dois
    """
    similarities = []
    for target_doi in example.citation_dois:
        results = db.select_by_doi(doi=target_doi, collection_name="qwen06_contributions")
        target_vectors = np.array(results["vector"].tolist())
        similarity_scores = np.dot(example.vector, target_vectors.T)
        similarities.append(np.max(similarity_scores))
    return similarities


def compute_margins(df: pd.DataFrame, target_col: str, hard_col: str, margin_col_name: str) -> None:
    """
    For each row in the DataFrame, computes the margin between each target similarity and the hardest non-target similarity.

    Args:
      df: DataFrame containing the data
      target_col: Name of the column with list of target similarities
      hard_col: Name of the column with list of hard non-target similarities
      margin_col_name: Name of the column to store the computed margins

    Returns:
      None (modifies df in place)
    """
    df[margin_col_name] = None
    for idx, row in df.iterrows():
        target_similarities = row[target_col]
        hardest_nontarget_similarity = max(row[hard_col])
        margins = [target_sim - hardest_nontarget_similarity for target_sim in target_similarities]
        df.at[idx, margin_col_name] = margins


# Compute target and hard similarities, then the margins
sample["target_similarities"] = sample.progress_apply(get_similarity_to_targets, axis=1)
sample["hard_dois"] = None
sample["hard_similarities"] = None
for idx, example in tqdm(sample.iterrows(), total=len(sample)):
    hard_records = get_hard_records(example, n=2)
    sample.at[idx, "hard_dois"] = [r["doi"] for r in hard_records]
    sample.at[idx, "hard_similarities"] = [r["metric"] for r in hard_records]

compute_margins(sample, target_col="target_similarities", hard_col="hard_similarities", margin_col_name="old_margins")
sample.head()

100%|██████████| 20/20 [00:00<00:00, 187.29it/s]
100%|██████████| 20/20 [00:01<00:00, 15.97it/s]


Unnamed: 0,source_doi,sent_original,sent_no_cit,sent_idx,citation_dois,pubdate,resolved_bibcodes,sent_cit_masked,vector,target_similarities,hard_dois,hard_similarities,old_margins
0,10.1146/annurev-astro-081710-102521,Their abundance is important because molecular...,"In this limit, the important reactions are dis...",318,[10.1046/j.1365-8711.2002.04940.x],20110901,[2002MNRAS.329...18F],Their abundance is important because molecular...,"[-0.036558926, -0.016381508, -0.0110567855, 0....",[0.6072211140974748],"[10.1088/0004-637X/703/2/1416, 10.1111/j.1365-...","[0.6299331188201904, 0.5999367833137512]",[-0.022712004722715617]
1,10.1146/annurev-astro-081817-051826,It is important to point out that the fraction...,Gesicki et al. (2014) derived masses and ages ...,231,"[10.1051/0004-6361/201220678, 10.1051/0004-636...",20180901,"[2013A&A...549A.147B, 2017A&A...605A..89B]",It is important to point out that the fraction...,"[-0.016821573, -0.014033405, -0.008635518, 0.0...","[0.5379291839783749, 0.5663974073131282]","[10.1093/mnras/stx373, 10.1051/0004-6361:20021...","[0.6454614996910095, 0.6320533156394958]","[-0.10753231571263466, -0.07906409237788137]"
2,10.1007/s00159-010-0029-x,How could the seed massive black holes have gr...,This argument is particularly important at ear...,259,"[10.1086/422910, 10.1086/427065, 10.1086/50744...",20100701,"[2004ApJ...613...36H, 2005ApJ...620...59S, 200...",How could the seed massive black holes have gr...,"[-0.018864796, -0.06410703, -0.0063532703, 0.0...","[0.7070825246559626, 0.6650816675743063, 0.662...","[10.1111/j.1365-2966.2006.10467.x, 10.1111/j.1...","[0.7770616412162781, 0.7540676593780518]","[-0.06997911656031552, -0.11197997364197176, -..."
3,10.1146/annurev.aa.31.090193.003441,Nature has somehow solved this problem in doub...,These models can reproduce the observed spectr...,457,[10.1086/161053],19930101,[1983ApJ...269..423R],Nature has somehow solved this problem in doub...,"[0.044838704, -0.013867467, -0.007506692, 0.04...",[0.5737327576399296],"[10.1086/164480, 10.1086/155083]","[0.6610962748527527, 0.6198835372924805]",[-0.0873635172128231]
4,10.1007/s00159-012-0055-y,"However, a similar linewidth–size scaling law ...","Size, internal velocity dispersion and column ...",306,[10.1051/0004-6361:20020629],20121101,[2002A&A...390..307O],"However, a similar linewidth–size scaling law ...","[-0.009350214, -0.056160886, -0.00880815, -0.0...",[0.4668131439478381],"[10.1086/169766, 10.1086/177465]","[0.6350903511047363, 0.5991689562797546]",[-0.1682772071568982]


In [9]:
margins = pd.to_numeric(sample.explode(column="old_margins")["old_margins"], errors="coerce").dropna()
margins.describe()

count    28.000000
mean     -0.072390
std       0.063403
min      -0.168277
25%      -0.114911
50%      -0.083214
75%      -0.031465
max       0.116795
Name: old_margins, dtype: float64

## Process the dois


In [10]:
dois_to_process = set(doi for dois in sample.citation_dois for doi in dois).union(
    doi for dois in sample.hard_dois for doi in dois
)
print(f"DOI's to process: {len(dois_to_process)}")

# Load research papers so we can get full text by doi
research = pd.read_json("../data/research_used.jsonl", lines=True)
research = research[research["doi"].isin(dois_to_process)].reset_index(drop=True)
print(f"Loaded {len(research)} research papers")


def doi_to_paper(doi: str) -> str:
    record = research[research["doi"] == doi].iloc[0]
    return record["title"] + "\n\n" + record["abstract"] + "\n\n" + record["body"]

# Test:
doi = list(dois_to_process)[0]
print(doi_to_paper(doi)[:500])

DOI's to process: 67
Loaded 67 research papers
Aperture Synthesis Maps of HDO Emission in Orion-KL

The 1<SUB>10</SUB>-1<SUB>11</SUB> transition of deuterated water has been mapped toward the Kleinmann-Low nebula in Orion with the Hat Creek millimeter interferometer. The synthesized beamwidth is 3arcsec.4. The "hot core", "plateau", and "compact ridge" emission regions can be identified on the maps. The compact ridge appears to consist of streamers of gas which connect to the hot core clump and are directed away from the source IRc 2. The HD


In [11]:
from openai import OpenAI
import os


def bind_client(func):
    """
    Decorator to bind OpenAI client to a function that will provide DeepSeek API access
    """
    client = OpenAI(api_key=os.getenv("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com")

    def wrapper(*args, **kwargs):
        return func(client, *args, **kwargs)

    return wrapper


@bind_client
def deepseek(client, prompt: str) -> str:
    """
    Sends a prompt to the DeepSeek API (using DeepSeek-V3.1 non-thinking model)

    Expects a prompt that will instruct the model to respond with a JSON object.
    However, the function returns the raw string response, to allow for validation and
    error handling in multiple passes without losing the original response
    """
    response = client.chat.completions.create(
        model="deepseek-chat",
        messages=[{"role": "system", "content": prompt}],
        stream=False,
        response_format={"type": "json_object"},
    )
    return response.choices[0].message.content


response = deepseek("Respond with a JSON object with keys 'greeting' and 'farewell'")
print(response)
print(json.loads(response))

{
    "greeting": "Hello!",
    "farewell": "Goodbye!"
}
{'greeting': 'Hello!', 'farewell': 'Goodbye!'}


In [None]:
with open("../src/citeline/llm/prompts/original_contributions_v2.txt", "r") as f:
    prompt_template = f.read()

with open("new_findings.jsonl", "w") as f:
    for doi in tqdm(dois_to_process):
        paper = doi_to_paper(doi)
        prompt = prompt_template.format(paper=paper)
        try:
            response = deepseek(prompt)
        except Exception as e:
            print(f"Error processing doi {doi}: {e}")
            continue
        try:
            data = json.loads(response)
            data["doi"] = doi
            f.write(json.dumps(data) + "\n")
        except json.JSONDecodeError:
            print(f"Failed to decode JSON for doi {doi}. Response was:\n{response}")
            with open("failed_dois.txt", "a") as f_fail:
                f_fail.write(doi + "\n")
            continue

 94%|█████████▍| 63/67 [22:26<01:28, 22.11s/it]

In [None]:
new_findings = pd.read_json("new_findings.jsonl", lines=True)
print(f"Loaded {len(new_findings)} new findings")

new_findings_exploded = new_findings.explode("findings")
new_findings_exploded["vector"] = embedder(new_findings_exploded["findings"].tolist()).tolist()
new_findings_exploded.head()

In [None]:
# Get new similarity to target
sample['new_target_similarities'] = None
sample['new_hard_similarities'] = None

def get_vectors_by_doi(doi: str) -> np.ndarray:
    return np.array(new_findings_exploded[new_findings_exploded["doi"] == doi]["vector"].tolist())


for idx, row in sample.iterrows():
    # For each target doi, compute the max similarity wrt the new embeddings
    query_vector = row['vector']
    new_similarities = []
    for target_doi in row['citation_dois']:
        target_vectors = get_vectors_by_doi(target_doi)
        new_similarities.append(np.max(np.dot(query_vector, target_vectors.T)))
    sample.at[idx, 'new_target_similarities'] = new_similarities

    # Collect all the hard vectors, compute the hard similarities
    new_hard_similarities = []
    for doi in row['hard_dois']:
        candidate_vectors = get_vectors_by_doi(doi)
        new_hard_similarities.append(np.max(np.dot(query_vector, candidate_vectors.T)))
    sample.at[idx, 'new_hard_similarities'] = new_hard_similarities

sample.head()

In [None]:
compute_margins(sample, target_col="new_target_similarities", hard_col="new_hard_similarities", margin_col_name="new_margins")
sample.head()

In [None]:
def compute_margin_diffs(df: pd.DataFrame, new_col: str, ref_col: str) -> pd.Series:
    new_values = df[new_col].explode().tolist()
    ref_values = df[ref_col].explode().tolist()
    diffs = [new - ref for new, ref in zip(new_values, ref_values)]
    return pd.Series(diffs)

diffs = compute_margin_diffs(sample, new_col="new_margins", ref_col="old_margins")
print(diffs.describe())



## Error analysis

Let's look at where the new margin is still negative (the target document vectors aren't as close to the query as the hard examples)

In [None]:
error_rows = sample[sample['new_margins'].apply(lambda margins: any(margin < 0 for margin in margins))]
error_rows

In [None]:
def analyze_error_row(idx: int) -> None:

    example = error_rows.iloc[idx]
    print("Original sentence:")
    pprint(example['sent_original'])

    hardest_idx = np.argmax(example['new_hard_similarities'])
    hard_doi = example['hard_dois'][hardest_idx]
    hard_findings = new_findings_exploded[new_findings_exploded['doi'] == hard_doi]
    hard_vectors = np.array(hard_findings['vector'].tolist())
    hard_similarities = np.dot(example['vector'], hard_vectors.T)
    hardest_indices = np.argsort(-hard_similarities)[:3]
    for idx in hardest_indices:
        print(f"Similarity: {hard_similarities[idx]:.4f}, DOI: {hard_findings.iloc[idx]['doi']}")
        pprint(hard_findings.iloc[idx]['findings'])
        print("-----")

analyze_error_row(0)

### Revision 2

In [None]:
with open("../src/citeline/llm/prompts/original_contributions_v2.txt", "r") as f:
    prompt_template = f.read()

NEW_FINDINGS_FILENAME = "new_findings_v2.jsonl"

with open(NEW_FINDINGS_FILENAME, "w") as f:
    for doi in tqdm(dois_to_process):
        paper = doi_to_paper(doi)
        prompt = prompt_template.format(paper=paper)
        try:
            response = deepseek(prompt)
        except Exception as e:
            print(f"Error processing doi {doi}: {e}")
            continue
        try:
            data = json.loads(response)
            data["doi"] = doi
            f.write(json.dumps(data) + "\n")
        except json.JSONDecodeError:
            print(f"Failed to decode JSON for doi {doi}. Response was:\n{response}")
            with open("failed_dois.txt", "a") as f_fail:
                f_fail.write(doi + "\n")
            continue

In [None]:
new_findings = pd.read_json(NEW_FINDINGS_FILENAME, lines=True)
print(f"Loaded {len(new_findings)} new findings")

new_findings_exploded = new_findings.explode("findings")
new_findings_exploded["vector"] = embedder(new_findings_exploded["findings"].tolist()).tolist()
new_findings_exploded.head()

In [None]:
# Save previous iteration and reset df for new results
sample_old = sample.copy()

# Get new similarity to target
sample["new_target_similarities"] = None
sample["new_hard_similarities"] = None

for idx, row in sample.iterrows():
    # For each target doi, compute the max similarity wrt the new embeddings
    query_vector = row["vector"]
    new_similarities = []
    for target_doi in row["citation_dois"]:
        target_vectors = get_vectors_by_doi(target_doi)
        new_similarities.append(np.max(np.dot(query_vector, target_vectors.T)))
    sample.at[idx, "new_target_similarities"] = new_similarities

    # Collect all the hard vectors, compute the hard similarities
    new_hard_similarities = []
    for doi in row["hard_dois"]:
        candidate_vectors = get_vectors_by_doi(doi)
        new_hard_similarities.append(np.max(np.dot(query_vector, candidate_vectors.T)))
    sample.at[idx, "new_hard_similarities"] = new_hard_similarities

compute_margins(
    sample, target_col="new_target_similarities", hard_col="new_hard_similarities", margin_col_name="new_margins"
)
sample.head()

diffs = compute_margin_diffs(sample, new_col="new_margins", ref_col="old_margins")
print(diffs.describe())

In [None]:
error_rows = sample[sample["new_margins"].apply(lambda margins: any(margin < 0 for margin in margins))]
print(f"Number of rows with negative new margins: {len(error_rows)}")
error_rows

In [None]:
# Print the target contributions for an error row
idx = 0
analyze_error_row(idx)

def print_target_contributions(idx: int) -> None:
    row = error_rows.iloc[idx]
    print("Original sentence:")
    print(row["sent_original"])

    target_dois = row["citation_dois"]
    target_records = {doi: new_findings_exploded[new_findings_exploded["doi"] == doi]['findings'] for doi in target_dois}
    pprint("Target findings:")
    for doi in target_records:
        print(f"DOI: {doi}")
        for i, finding in enumerate(target_records[doi]):
            print(f"{i}: {finding}")
        print("-----")
print(f"Sentence in context:\n{error_rows.iloc[idx]['sent_no_cit']}")
print_target_contributions(idx)

In [None]:
error_rows.iloc[idx]['sent_no_cit']

In [None]:
target_vector = embedder(
    [
        "Deep optical images shows a faint elliptical ring structure orbiting the spiral galaxy NGC 5907",
    ]
)[0]
# query_vector = error_rows.iloc[0]["vector"]
query_vector = embedder(
    [
        "However, deep optical images of a number of spiral galaxies, such as NGC 253, M 83, M 104, NGC 2855, (Malin and Hadley 1997) and NGC 5907 (), do show unusual, faint features in their surroundings.",
    ]
)[0]
print(f"Cosine similarity: {query_vector.dot(target_vector):.4f}")

In [None]:
hard_vector = embedder(["Most extended and complete luminosity function obtained for Galactic bulge to date"])[0]
print(f"Cosine similarity: {np.dot(hard_vector, query_vector):.4f}")

In [None]:
for i, row in new_findings_exploded[new_findings_exploded["doi"] == "10.1086/164480"].iterrows():
    print(f"Finding {i}:")
    pprint(row["findings"])
    print("-----")