In [None]:
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
from citeline.embedders import Embedder
from database.milvusdb import MilvusDB
from citeline.query_expander import get_expander

# Register tqdm's `progress_apply` on pandas objects (used when embedding below).
tqdm.pandas()

# Embedder used to vectorise the example sentences.
# NOTE(review): device="mps" is hard-coded — presumably this only works on
# machines with that backend available; confirm before running elsewhere.
embedder = Embedder.create(
    "Qwen/Qwen3-Embedding-0.6B",
    device="mps",
    normalize=True,
    for_queries=True,
)

In [2]:
# Name of the query-expansion strategy handed to `get_expander` below.
QUERY_EXPANSION = "add_prev_3"

In [3]:
# Load the examples (JSON Lines: one record per line).
examples = pd.read_json("data/dataset/nontrivial_checked.jsonl", lines=True)

# Build the configured query expander and overwrite `sent_no_cit` with its output.
expander = get_expander(QUERY_EXPANSION, path_to_data="data/preprocessed/reviews.jsonl")
print(f"Using query expansion: {expander}")
examples["sent_no_cit"] = expander(examples)

# Embed the expanded text one row at a time and keep the result in a `vector` column.
# This is the slow step of the notebook (the recorded run took ~18 minutes).
examples["vector"] = examples.progress_apply(
    lambda example: embedder([example["sent_no_cit"]])[0],
    axis=1,
)

# Denormalise: one row per (sentence, cited DOI) pair, then name the DOI column `target_doi`.
examples = examples.explode("citation_dois", ignore_index=True)
print(f"Number of samples after denormalization: {len(examples)}")
examples = examples.rename(columns={"citation_dois": "target_doi"})

Using query expansion: QueryExpander(name=add_prev_3, data_length=2980)


100%|██████████| 14735/14735 [18:30<00:00, 13.27it/s] 

Number of samples after denormalization: 18801





In [4]:
# Preview the first five denormalised rows.
examples.head(n=5)

Unnamed: 0,source_doi,sent_original,sent_no_cit,sent_idx,target_doi,pubdate,resolved_bibcodes,sent_cit_masked,vector
0,10.1016/j.newar.2024.101694,"Subsequently, Andrews et al. (2017) selected a...",1 pc. Similar separation distributions had bee...,58,10.1093/mnras/stx2000,20240601,[2017MNRAS.472..675A],"Subsequently, [REF] selected a wide binary can...","[-0.013162542, -0.09026443, -0.0120118465, -0...."
1,10.1016/j.newar.2024.101694,Andrews et al. (2017) investigated how the sep...,"Subsequently, Andrews et al. (2017) selected a...",61,10.1093/mnras/stx2000,20240601,[2017MNRAS.472..675A],[REF] investigated how the separation of their...,"[-0.07659445, -0.064264126, -0.007819046, 0.00..."
2,10.1016/j.newar.2024.101694,This led Andrews et al. (2017) to conclude tha...,Andrews et al. (2017) investigated how the sep...,64,10.1093/mnras/stx2000,20240601,[2017MNRAS.472..675A],This led [REF] to conclude that most of the pa...,"[-0.041363332, -0.067232065, -0.00943872, 0.03..."
3,10.1016/j.newar.2024.101694,It may also owe in part to the mass ratio dist...,The sample contains 97 resolved WD+MS binaries...,90,10.1093/mnras/stz2480,20240601,[2019MNRAS.489.5822E],It may also owe in part to the mass ratio dist...,"[-0.04453195, -0.07250524, -0.009316281, 0.052..."
4,10.1016/j.newar.2024.101694,Hwang et al. (2022c) used a related method to ...,This approach forward-models the distribution ...,110,10.3847/2041-8213/ac7c70,20240601,[2022ApJ...933L..32H],[REF] used a related method to study the eccen...,"[-0.04475006, -0.01653341, -0.007177436, 0.053..."


In [None]:
db = MilvusDB()


def most_similar_to_query(example: pd.Series) -> np.ndarray:
    """
    Return the stored vector of `example.target_doi` that best matches the example.

    All entities for the example's target DOI are fetched from the
    "qwen06_chunks" collection, and the candidate whose dot product with
    `example["vector"]` is largest is returned.
    (The dot product equals cosine similarity only if the stored vectors are
    unit-normalised — TODO confirm for this collection.)

    Args:
        example: a row with `target_doi` and an already-computed `vector`.

    Returns:
        The best-matching candidate vector as a 1-D array.

    Raises:
        ValueError: if the collection returns no entities for the DOI.
    """
    candidates = db.select_by_doi(example.target_doi, collection_name="qwen06_chunks")
    if len(candidates) == 0:
        # Without this guard np.stack fails with an opaque "need at least one array" error.
        raise ValueError(
            f"No entities found for DOI {example.target_doi!r} in collection 'qwen06_chunks'"
        )

    # Stack the per-entity vectors into a (num_candidates, dim) array.
    candidate_vectors = np.stack(candidates["vector"])
    similarities = np.dot(candidate_vectors, example["vector"])
    return candidate_vectors[np.argmax(similarities)]

In [None]:
# Query matrix: one column per example (dim x num_examples).
Q = np.array(list(examples['vector'])).T

# Target matrix: column i is the best-matching stored vector for example i.
# This issues one database lookup per row, so it is slow on the full dataset.
V = np.array(
    [
        most_similar_to_query(example)
        for _, example in tqdm(examples.iterrows(), total=len(examples))
    ]
).T
print("Q shape:", Q.shape)
print("V shape:", V.shape)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Cross-covariance between target and query embeddings (dim x dim).
M = np.matmul(V, Q.T)

# Diverging colour map centred on zero so positive and negative entries are distinguishable.
plt.figure(figsize=(9, 7))
sns.heatmap(
    M,
    center=0,
    cmap="coolwarm",
    cbar_kws={"label": "value"},
)
plt.title("Matrix heatmap (seaborn)")
plt.tight_layout()
plt.savefig("heatmap_seaborn.png", dpi=200)
plt.show()

In [None]:

# Factor the cross-covariance and keep only its orthogonal part: with
# M = U diag(S) Vt, the matrix R = U Vt is the orthogonal-Procrustes solution
# for mapping the query columns of Q onto the target columns of V.
U, S, Vt = np.linalg.svd(M, full_matrices=True)
R = np.matmul(U, Vt)

# A determinant of -1 means R contains a reflection rather than being a pure rotation.
print("det R:", np.linalg.det(R))
# Disabled: force a proper rotation (det = +1) by flipping the last singular direction.
# if np.linalg.det(R) < 0:
#     D = np.eye(R.shape[0])
#     D[-1, -1] = -1
#     R = U @ D @ Vt
# print("corrected det R:", np.linalg.det(R))

In [None]:
# Numerical rank: how many singular values exceed an absolute threshold of 1e-8.
print("rank approx:", np.count_nonzero(S > 1e-8))
# Largest 30 and smallest 5 singular values, to eyeball the spectrum's decay.
print("singular values (top/last):", S[:30], S[-5:])

In [None]:
# Every DOI that appears as a citation target; used as the pool for negative sampling.
research_dois = set(examples.target_doi)
# Print a count and a small preview instead of dumping the whole set into the output.
print(f"{len(research_dois)} unique target DOIs, e.g. {sorted(research_dois, key=str)[:5]}")

In [None]:
import random

def cosine_similarity(a, b):
    """Return the cosine of the angle between vectors `a` and `b`.

    Computed as the dot product divided by the product of the Euclidean norms.
    No guard is applied for zero-length vectors.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / norm_product


def negative_vectors(example: pd.Series, n=5, rng=None):
    """
    Sample `n` vectors from documents that the example's sentence does NOT cite.

    The sentence is identified by (`source_doi`, `sent_idx`); every `target_doi`
    on its rows in `examples` is excluded from the pool `research_dois`. `n`
    DOIs are drawn from what remains, and for each one a single record is drawn
    at random from the "qwen06_chunks" collection.

    Args:
        example: a row of `examples`.
        n: number of negative vectors to return.
        rng: optional `random.Random` instance. When given, both the DOI draw
            and the per-DOI record draw are driven by it, so results are
            repeatable for a fixed seed. When omitted, the global `random`
            state and pandas' default sampling are used, as before.

    Returns:
        Array of shape (n, embedder.dim), one negative vector per row.

    Raises:
        ValueError: if fewer than `n` uncited DOIs exist (from `sample`), or
            if a sampled DOI has no records in the collection.
    """
    sampler = random if rng is None else rng

    # All DOIs cited by this sentence (the frame holds one row per cited DOI).
    same_sentence = (examples.source_doi == example.source_doi) & (examples.sent_idx == example.sent_idx)
    citation_dois = set(examples.loc[same_sentence, "target_doi"])

    # Sort so the candidate order does not depend on set iteration order;
    # otherwise even a seeded draw would differ between interpreter sessions.
    neg_dois = sorted(research_dois - citation_dois, key=str)
    neg_samples = sampler.sample(neg_dois, n)

    # Get vectors for negative samples
    neg_vectors = np.zeros((n, embedder.dim))
    for i, doi in enumerate(neg_samples):
        records = db.select_by_doi(doi, collection_name="qwen06_chunks")
        if len(records) == 0:
            raise ValueError(f"No entities found for DOI {doi!r} in collection 'qwen06_chunks'")
        # Derive a per-draw seed from `rng` so the record choice is repeatable too.
        record_seed = None if rng is None else rng.randrange(2**32)
        sample_record = records.sample(n=1, random_state=record_seed).iloc[0]
        neg_vectors[i] = np.array(sample_record['vector'])
    return neg_vectors


# Smoke-check the sampler on a single, arbitrarily chosen row.
res = negative_vectors(example=examples.iloc[998])
print(res)

Citation DOIs: {'10.1046/j.1365-8711.2000.03810.x', '10.1086/316394', '10.1086/186883'}

In [None]:
# For every example, measure how rotating the query by R changes its cosine
# similarity to (a) its best-matching target vector and (b) five sampled negatives.
# A useful rotation should raise (a) without raising (b).
if V.shape[1] != len(examples):
    raise ValueError(
        f"V has {V.shape[1]} columns but examples has {len(examples)} rows; rebuild V first"
    )

differences_to_target = []
differences_to_negative = []
for position, (row_label, row) in enumerate(examples.iterrows()):
    query_vector = row['vector']
    # Column `position` of V is the best-matching target vector computed for this
    # row when V was built, so reuse it instead of querying the database again.
    target_vector = V[:, position]
    aligned_query_vector = R @ query_vector

    before = cosine_similarity(query_vector, target_vector)
    after = cosine_similarity(aligned_query_vector, target_vector)
    differences_to_target.append(after - before)

    neg_vectors = negative_vectors(row, n=5)
    batch_distance_to_negative = [
        cosine_similarity(aligned_query_vector, neg_vector) - cosine_similarity(query_vector, neg_vector)
        for neg_vector in neg_vectors
    ]
    differences_to_negative.extend(batch_distance_to_negative)

    print(row_label)
    print(f"Improvement to target: {after - before:.4f}")
    print(f"Distance to negatives: {np.mean(batch_distance_to_negative):.4f} ± {np.std(batch_distance_to_negative):.4f}")
    print("---")
print(f"Average improvement (target): {np.mean(differences_to_target):.6f} ± {np.std(differences_to_target):.6f}")
print(f"Average improvement (negative): {np.mean(differences_to_negative):.6f} ± {np.std(differences_to_negative):.6f}")

In [None]:
# Sanity check: R should be square, with the embedding dimension on both axes.
print(np.shape(R))

In [None]:
import matplotlib.pyplot as plt
# U, S, Vt = np.linalg.svd(M)
# print("U shape:", U.shape)
# print("S shape:", S.shape)
# print("Vt shape:", Vt.shape)
# print("Vt_trunc shape:", Vt[:1, :].shape)
# print("S_trunc shape:", S[:1].shape)
# print("U_trunc shape:", U[:, 0:1].shape)
# trunc = U[:, 0:1] @ (S[:1] * Vt[0:1, :])
# print("trunc shape:", trunc.shape)
# print("trunc rank:", np.linalg.matrix_rank(trunc))
# print(trunc)


# Singular-value spectrum of M on a log scale, labelled so the figure stands alone.
plt.semilogy(S)  # plot singular values
plt.xlabel("singular value index")
plt.ylabel("singular value (log scale)")
plt.title("Singular values of M")
plt.show()

In [None]:
# Persist the full-rank orthogonal map R.
# NOTE(review): the file name says "n2000" — confirm it still matches the number
# of examples R was fitted on in this run.
np.save(file='qwen06_chunks_rotation_n2000.npy', arr=R)

In [None]:
# Build a rotation that acts only inside the span of the top-r left singular
# vectors of M and leaves the orthogonal complement untouched.
sigma1 = S[0]
eps = 1e-6  # try 1e-6 .. 1e-8 as needed
r = 50
# r = np.searchsorted(S / sigma1 < eps, True)  # first index where ratio < eps
if r == 0:
    r = len(S)  # fallback if no small ones found
# Report the rank after the fallback so the printed value is the one actually used.
print(f"Chosen rank r: {r}")
# alternative: r = np.searchsorted(np.cumsum(S**2) / np.sum(S**2), 0.99) + 1

# build small rotation in top-r
Ur = U[:, :r]  # n x r
Vr = Vt[:r, :].T  # n x r  (since Vt[:r,:] is r x n)
# r x r cross-covariance expressed in the truncated singular bases.
# NOTE(review): Ur and Vr are M's own singular vectors, so this product is
# diag(S[:r]); its orthogonal factor Us @ Vts is therefore (numerically) the
# identity, which makes R_full ~ identity as well. Confirm this is intended —
# a non-trivial low-rank rotation needs a basis that is not taken from M's SVD.
Msmall = Ur.T @ M @ Vr

# directly SVD the r x r matrix (numerically stable):
Us, Ss, Vts = np.linalg.svd(Msmall, full_matrices=False)
Rsmall = Us @ Vts
# ensure proper rotation (det +1)
if np.linalg.det(Rsmall) < 0:
    D = np.eye(r)
    D[-1, -1] = -1
    Rsmall = Us @ D @ Vts

# map basis Ur to rotated basis Ur @ Rsmall @ Ur.T, then add identity on complement.
# The identity is sized from U rather than a hard-coded 1024 so the cell works
# for any embedding dimension.
R_full = Ur @ Rsmall @ Ur.T + (np.eye(U.shape[0]) - Ur @ Ur.T)
print("det R_full:", np.linalg.det(R_full))
print(R_full.shape)
print(R_full)

In [None]:
# NOTE(review): this writes `R` (the full-rank map), not the `R_full` matrix
# built in the preceding cell — confirm which one the "fullR" file should hold.
np.save(file='qwen06_chunks_fullR.npy', arr=R)

In [None]:
# Leading singular directions of M: v1 is the first row of Vt (query side of
# M = V @ Q.T), u1 is the first column of U (target side).
v1 = Vt[0]
u1 = U.T[0]
print("v1 shape:", v1.shape)
print("u1 shape:", u1.shape)

In [None]:
# Stack the per-example query embeddings into a (num_examples, dim) array.
query_vectors = np.stack(list(examples['vector']))
print("query_vectors shape:", query_vectors.shape)
# Coordinate of every query along the leading right singular direction v1.
X = v1 @ query_vectors.T
print(X.shape)

In [None]:
# V already holds the best-matching target vector for every example (one per
# column), so reuse it instead of repeating one database lookup per row.
target_vectors = V.T
# Coordinate of every target along the leading left singular direction u1.
Y = np.dot(u1, target_vectors.T)
print(Y.shape)
plt.scatter(X, Y)
plt.xlabel("query projection onto v1")
plt.ylabel("target projection onto u1")
plt.title("Query vs. target projections on the leading singular directions")
plt.show()

In [None]:
# 2 x 2 Pearson correlation matrix of the query and target projections.
np.corrcoef(x=X, y=Y)