In [41]:
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
from embedders import Embedder
from database.milvusdb import MilvusDB
from query_expander import get_expander

from compute_difference_vector import get_sample_df
tqdm.pandas()
embedder = Embedder.create("Qwen/Qwen3-Embedding-0.6B", device="mps", normalize=True, for_queries=True)
samples = get_sample_df(1000, embedder)

Loaded 1000 examples
Using query expansion: QueryExpander(name=add_prev_3, data_length=2980)


100%|██████████| 1000/1000 [02:08<00:00,  7.76it/s]

Number of samples after denormalization: 1265





In [42]:
samples.head()

Unnamed: 0,source_doi,sent_original,sent_no_cit,sent_idx,target_doi,pubdate,resolved_bibcodes,sent_cit_masked,vector
0,10.1146/annurev-astro-082812-141031,"Indeed, Valenti et al. (2009) have argued that...","Jordan et al. (2012) , Kromer et al. (2013) , ...",595,10.1038/nature08023,20140801,[2009Natur.459..674V],"Indeed, [REF] have argued that SNe Iax are act...","[0.020763418, 0.014040946, -0.009526509, 0.012..."
1,10.1007/s00159-011-0047-3,Using a dedicated VLT/ISAAC multi-epoch SN sur...,"At even higher cluster redshift z ∼ 1, the Sup...",1029,10.1051/0004-6361/200911982,20111101,"[2009A&A...507...61S, 2009A&A...507...71G]",Using a dedicated VLT/ISAAC multi-epoch SN sur...,"[-0.005719765, -0.005792628, -0.012335606, 0.0..."
2,10.1007/s00159-011-0047-3,Using a dedicated VLT/ISAAC multi-epoch SN sur...,"At even higher cluster redshift z ∼ 1, the Sup...",1029,10.1051/0004-6361/200811254,20111101,"[2009A&A...507...61S, 2009A&A...507...71G]",Using a dedicated VLT/ISAAC multi-epoch SN sur...,"[-0.005719765, -0.005792628, -0.012335606, 0.0..."
3,10.1146/annurev-astro-081811-125615,"Madau et al. (1996 , 1998b ) and Lilly et al. ...",Redshifts z >4 have been confirmed from CO mea...,904,10.1093/mnras/283.4.1388,20140801,"[1996MNRAS.283.1388M, 1998ApJ...498..106M, 199...",[REF] and [REF] developed a different method w...,"[-0.02094898, -0.0007385412, -0.012135256, 0.0..."
4,10.1146/annurev-astro-081811-125615,"Madau et al. (1996 , 1998b ) and Lilly et al. ...",Redshifts z >4 have been confirmed from CO mea...,904,10.1086/305523,20140801,"[1996MNRAS.283.1388M, 1998ApJ...498..106M, 199...",[REF] and [REF] developed a different method w...,"[-0.02094898, -0.0007385412, -0.012135256, 0.0..."


In [43]:
db = MilvusDB()
def most_similar_to_query(example: pd.Series) -> np.ndarray:
    """
    Takes in an example (with 'vector' column already set), and from the candidates
    (returned entities with that doi from the database), returns the vector most similar
    to the example's vector.

    """
    # Converts 'vector' column to rows * dim array, holding the candidate vectors
    candidates = db.select_by_doi(example.target_doi, collection_name="qwen06_chunks")
    candidate_vectors = np.stack(candidates["vector"])
    best_idx = np.argmax(np.dot(candidate_vectors, example["vector"]))
    best_vector = candidate_vectors[best_idx]
    return best_vector

In [44]:
Q = np.array(samples['vector'].tolist()).T
print("Q shape:", Q.shape)

V = np.array([most_similar_to_query(row) for _, row in tqdm(samples.iterrows(), total=len(samples))]).T
print("V shape:", V.shape)

Q shape: (1024, 1265)


100%|██████████| 1265/1265 [00:10<00:00, 124.61it/s]


V shape: (1024, 1265)


In [86]:
M = V @ Q.T
U, S, Vt = np.linalg.svd(M)
R = U @ Vt

print("det R:", np.linalg.det(R))
# if np.linalg.det(R) < 0:
#     D = np.eye(R.shape[0])
#     D[-1, -1] = -1
#     R = U @ D @ Vt
# print("corrected det R:", np.linalg.det(R))

det R: -0.999999999999917


In [87]:
research_dois = set(samples.target_doi)
print(research_dois)

{'10.1086/301107', '10.1093/mnras/stw2819', '10.1088/2041-8205/730/2/L34', '10.1088/0004-637X/695/1/259', '10.1086/163559', '10.1088/2041-8205/711/2/L108', '10.1086/339355', '10.1086/174189', '10.1088/0004-637X/758/1/29', '10.1086/190641', '10.1111/j.1365-2966.2007.12085.x', '10.1086/503551', '10.1088/0004-637X/743/2/161', '10.48550/arXiv.astro-ph/0603179', '10.1086/374266', '10.1088/2041-8205/714/2/L233', '10.1007/BF00661821', '10.1086/160545', '10.3847/1538-4357/abbc6e', '10.1088/0004-637X/723/1/869', '10.1051/0004-6361:20078737', '10.1093/mnras/stz2626', '10.1093/pasj/63.sp2.S493', '10.1111/j.1365-2966.2005.09043.x', '10.1086/339033', '10.1111/j.1365-2966.2003.07115.x', '10.1086/163713', '10.18727/0722-6691/5120', '10.1103/PhysRev.124.925', '10.1051/0004-6361/201220622', '10.1086/339025', '10.1086/190559', '10.1007/BF00145545', '10.1093/mnras/193.2.337', '10.1086/301146', '10.1051/0004-6361:20053949', '10.1093/mnrasl/slt005', '10.1086/491657', '10.1126/science.1112997', '10.1086/176

In [88]:
import random

def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


def negative_vectors(example: pd.Series, n=5):
    # Get all target DOIs for this example, sample from those not cited
    rows = samples[(samples.source_doi == example.source_doi) & (samples.sent_idx == example.sent_idx)]
    citation_dois = set(rows.target_doi)
    neg_dois = list(research_dois - citation_dois)
    neg_samples = random.sample(neg_dois, n)
    
    # Get vectors for negative samples
    neg_vectors = np.zeros((n, embedder.dim))
    for i, doi in enumerate(neg_samples):
        records = db.select_by_doi(doi, collection_name="qwen06_chunks")
        sample_record = records.sample(n=1).iloc[0]
        neg_vectors[i] = np.array(sample_record['vector'])
    return neg_vectors


res = negative_vectors(samples.iloc[998])
print(res)

[[-0.04344578  0.0576531  -0.00312176 ... -0.01997764  0.04127572
  -0.01484109]
 [ 0.02326964 -0.02050474 -0.00339893 ... -0.04986836  0.03204567
  -0.03708823]
 [-0.03509971 -0.03268801 -0.00720697 ...  0.0036599   0.01352744
  -0.00130322]
 [-0.02752985 -0.05466044 -0.01236336 ... -0.00509659 -0.01865252
   0.00131505]
 [-0.04398447  0.0296942  -0.01012584 ... -0.02293189  0.02405622
   0.02059095]]


Citation DOIs: {'10.1046/j.1365-8711.2000.03810.x', '10.1086/316394', '10.1086/186883'}

In [89]:
df = db.select_by_doi("10.1046/j.1365-8711.2000.03810.x", collection_name="qwen06_chunks")
row = df.sample(n=1).iloc[0]
print(np.array(row['vector']))

[-0.04126437 -0.03557961 -0.00555971 ... -0.02695249 -0.01948428
 -0.04101685]


In [90]:
differences_to_target = []
differences_to_negative = []
for _, row in samples.iterrows():
    query_vector = row['vector']
    target_vector = most_similar_to_query(row)
    aligned_query_vector = R @ query_vector
    before = cosine_similarity(query_vector, target_vector)
    after = cosine_similarity(aligned_query_vector, target_vector)

    differences_to_target.append(after - before)
    neg_vectors = negative_vectors(row, n=5)
    batch_distance_to_negative = []
    for neg_vector in neg_vectors:
        before_neg = cosine_similarity(query_vector, neg_vector)
        after_neg = cosine_similarity(aligned_query_vector, neg_vector)
        batch_distance_to_negative.append(after_neg - before_neg)
        differences_to_negative.append(after_neg - before_neg)
    print(_)
    print(f"Improvement to target: {after - before:.4f}")
    print(f"Distance to negatives: {np.mean(batch_distance_to_negative):.4f} ± {np.std(batch_distance_to_negative):.4f}")
    print("---")
print(f"Average improvement (target): {np.mean(differences_to_target):.6f} ± {np.std(differences_to_target):.6f}")
print(f"Average improvement (negative): {np.mean(differences_to_negative):.6f} ± {np.std(differences_to_negative):.6f}")

0
Improvement to target: 0.2160
Distance to negatives: -0.0067 ± 0.0178
---
1
Improvement to target: 0.0590
Distance to negatives: -0.0038 ± 0.0167
---
2
Improvement to target: 0.0606
Distance to negatives: 0.0085 ± 0.0113
---
3
Improvement to target: 0.0922
Distance to negatives: 0.0138 ± 0.0115
---
4
Improvement to target: 0.1338
Distance to negatives: 0.0060 ± 0.0106
---
5
Improvement to target: 0.1397
Distance to negatives: 0.0038 ± 0.0200
---
6
Improvement to target: 0.1713
Distance to negatives: 0.0066 ± 0.0168
---
7
Improvement to target: 0.1273
Distance to negatives: 0.0166 ± 0.0221
---
8
Improvement to target: 0.0662
Distance to negatives: 0.0100 ± 0.0139
---
9
Improvement to target: 0.1141
Distance to negatives: 0.0147 ± 0.0178
---
10
Improvement to target: 0.1699
Distance to negatives: -0.0079 ± 0.0206
---
11
Improvement to target: 0.1042
Distance to negatives: -0.0050 ± 0.0226
---
12
Improvement to target: 0.2100
Distance to negatives: 0.0214 ± 0.0152
---
13
Improvement to 

In [91]:
print(R.shape)

(1024, 1024)


In [83]:
z = R @ np.array(row['vector'])
print(z.shape)
print(np.linalg.norm(z))
print(np.linalg.norm(z / np.linalg.norm(z)))

(1024,)
0.9999999694830948
0.9999999999999999


In [92]:
np.save('qwen06_chunks_reflection.npy', R)