# New collaborations are driven by search for expertise

### Brief overview


### Imports & Global Variables

In [35]:
import math
import os
import sys
from collections import defaultdict
from itertools import combinations

import numpy as np
import polars as pl
from box import Box
from scipy.spatial.distance import cosine
from scipy.stats import ks_2samp, wasserstein_distance
from tqdm import tqdm

# Put the parent directory on sys.path before importing the local `util` package.
sys.path.insert(0, os.path.abspath(".."))
from util.postgres import create_sqlalchemy_engine, query_polars

In [2]:
# -------------------- GLOBAL VARIABLES --------------------
# Relative path to the YAML configuration file read below.
PATH_TO_CONFIG_FILE = '../config.yaml'

# -------------------- LOAD CONFIGURATION --------------------
# Load the configuration file
config = Box.from_yaml(filename=PATH_TO_CONFIG_FILE)
# Create the database engine from the POSTGRES section of the configuration
# (credentials, host, port, database and schema all come from config.yaml).
pg_engine = create_sqlalchemy_engine(
    username=config.POSTGRES.USERNAME,
    password=config.POSTGRES.PASSWORD,
    host=config.POSTGRES.HOST,
    port=config.POSTGRES.PORT,
    database=config.POSTGRES.DATABASE,
    schema=config.POSTGRES.SCHEMA
)
# Seed numpy's global random generator with the configured seed
np.random.seed(config.RANDOM_SEED)
# Number of rows requested per cursor.fetchmany() call when reading in batches
batch_size = 10000

### Querying the data

In [6]:
%%time
# Read the article embeddings through a raw DBAPI cursor, `batch_size` rows at
# a time, and assemble them into a single Polars DataFrame.
#
# The raw connection is kept in a variable so it can be closed explicitly:
# the cursor's context manager only closes the cursor, not the connection.
raw_conn = pg_engine.raw_connection()
try:
    with raw_conn.cursor() as cur:
        cur.execute("""
            SELECT article_id, article_embedding
            FROM g_included_article_embedding
        """)
        # Collect one DataFrame per batch and concatenate once at the end;
        # concatenating inside the loop re-copies all previous rows each time.
        chunks = []
        ix = 0
        while True:
            rows = cur.fetchmany(size=batch_size)
            if not rows:
                break
            chunks.append(
                pl.DataFrame(rows, schema=["article_id", "article_embedding"], orient="row")
            )
            # Report the real batch length (the last batch is usually shorter).
            print(f"Rows fetched {len(rows)} for batch {ix}")
            ix += 1
finally:
    raw_conn.close()

# An empty result set yields an empty DataFrame, as before.
article_embedding_df: pl.DataFrame = (
    pl.concat(chunks, how="vertical") if chunks else pl.DataFrame()
)

Rows fetched 10000 for batch 0
Rows fetched 10000 for batch 1
Rows fetched 10000 for batch 2
Rows fetched 10000 for batch 3
Rows fetched 10000 for batch 4
Rows fetched 10000 for batch 5
Rows fetched 10000 for batch 6
Rows fetched 10000 for batch 7
Rows fetched 10000 for batch 8
Rows fetched 10000 for batch 9
Rows fetched 10000 for batch 10
Rows fetched 10000 for batch 11
Rows fetched 10000 for batch 12
Rows fetched 10000 for batch 13
Rows fetched 10000 for batch 14
Rows fetched 10000 for batch 15
Rows fetched 10000 for batch 16
Rows fetched 10000 for batch 17
Rows fetched 10000 for batch 18
Rows fetched 10000 for batch 19
Rows fetched 10000 for batch 20
Rows fetched 10000 for batch 21
Rows fetched 10000 for batch 22
Rows fetched 10000 for batch 23
Rows fetched 10000 for batch 24
Rows fetched 10000 for batch 25
Rows fetched 10000 for batch 26
Rows fetched 10000 for batch 27
Rows fetched 10000 for batch 28
Rows fetched 10000 for batch 29
Rows fetched 10000 for batch 30
Rows fetched 10000

In [53]:
%%time
# One row per (article, author, co-author) with its new-pair flag and the
# article's publication date, restricted to pairs in which both authors are
# in g_included_author.
# Plain string literal: the query interpolates nothing, so no f-prefix is needed.
sql_query = """
SELECT f.article_id,
       f.author_id,
       f.co_author_id,
       f.is_new_author_pair,
       a.article_publication_dt
FROM fct_new_author_pair f
    INNER JOIN dim_article a 
        ON a.article_id = f.article_id
WHERE f.author_id IN (SELECT author_id FROM g_included_author)
    AND f.co_author_id IN (SELECT author_id FROM g_included_author)
"""
with pg_engine.connect() as conn:
    df_author_pair = query_polars(conn=conn, query_str=sql_query)

CPU times: user 15.1 s, sys: 2.46 s, total: 17.6 s
Wall time: 29.2 s


In [54]:
%%time
# Distinct (article, author, publication date) rows for included authors,
# ordered by publication date ascending.
# Plain string literal: the query interpolates nothing, so no f-prefix is needed.
sql_query = """
SELECT DISTINCT article_id,
                author_id,
                article_publication_dt
FROM fct_collaboration
WHERE author_id IN (SELECT author_id FROM g_included_author)
ORDER BY article_publication_dt ASC
"""

with pg_engine.connect() as conn:
    df_collab = query_polars(conn=conn, query_str=sql_query)

CPU times: user 3.37 s, sys: 710 ms, total: 4.08 s
Wall time: 9.13 s


### Preprocessing for optimization

In [22]:
# Map article_id -> embedding for constant-time lookups later on.
# Embeddings that arrive as Python lists become float32 numpy arrays;
# any other value is stored unchanged.
article_embeddings = {}
for record in tqdm(article_embedding_df.iter_rows(named=True)):
    vector = record["article_embedding"]
    article_embeddings[record["article_id"]] = (
        np.array(vector, dtype=np.float32) if isinstance(vector, list) else vector
    )

354102it [01:09, 5125.25it/s] 


In [56]:
# Map author_id -> list of (article_id, publication_dt) tuples.
author_articles_map = defaultdict(list)
for record in tqdm(df_collab.iter_rows(named=True)):
    author_articles_map[record["author_id"]].append(
        (record["article_id"], record["article_publication_dt"])
    )

# Keep every author's list ordered by publication date (second tuple element)
# so that "before/after" lookups can rely on chronological order.
for articles in author_articles_map.values():
    articles.sort(key=lambda entry: entry[1])

973511it [00:07, 137012.12it/s]


In [58]:
# Map (article_id, {author, co-author}) -> is_new_author_pair flag.
# The author pair is stored as a frozenset so the lookup does not depend on
# the order in which the two authors are given.
is_new_map = {}

for record in tqdm(df_author_pair.iter_rows(named=True)):
    pair = frozenset([record["author_id"], record["co_author_id"]])
    is_new_map[(record["article_id"], pair)] = record["is_new_author_pair"]

2199026it [00:34, 64080.84it/s] 


In [59]:
# Computing top-5 closest articles for an author relative to article p
def get_top_k_closest_articles(article_p_id: int, author_id: int, k: int = 5) -> list:
    """
    Rank one author's articles by cosine distance to article p and keep the k nearest.

    Args:
        article_p_id: ID of the reference article p. Its embedding is read from
            the module-level ``article_embeddings`` dict.
        author_id: ID of the author whose articles (taken from the module-level
            ``author_articles_map``) are ranked.
        k: Maximum number of results to return.

    Returns:
        A list of ``(article_id, distance)`` tuples sorted by ascending cosine
        distance to p, at most ``k`` long. The list is empty when p has no
        embedding, when the author is unknown, or when none of the author's
        articles has an embedding.

    Notes:
        Candidates are not filtered by publication date or by article ID, so if
        ``article_p_id`` is itself among the author's articles it is ranked too.
    """
    emb_p = article_embeddings.get(article_p_id)
    if emb_p is None:
        # Missing (or null) embedding for p: nothing can be compared.
        return []

    # Use .get() rather than indexing: author_articles_map is a defaultdict,
    # and indexing it would insert an empty list for every unknown author.
    author_article_list = author_articles_map.get(author_id, ())

    distances = []
    for art_id, _pub_dt in author_article_list:
        emb_a = article_embeddings.get(art_id)
        if emb_a is None:
            # Article without a usable embedding: skip it.
            continue
        distances.append((art_id, cosine(emb_p, emb_a)))

    # Closest first, then truncate to k.
    distances.sort(key=lambda item: item[1])
    return distances[:k]

In [60]:
# Computing the distance measure for each pair (a_i, a_j) given article p
#    - We retrieve the top-5 articles closest to p from a_i and from a_j
#    - We compare all 5 x 5 cross pairs in embedding space and keep the
#      maximum cosine distance as the pair's distance.


def compute_pair_distance(article_p_id: int, a_i: int, a_j: int, k: int = 5) -> float:
    """
    Compute the distance for the author pair (a_i, a_j) given an article p.

    Steps:
      1) Retrieve the (up to) k articles of a_i that are closest to p.
      2) Retrieve the (up to) k articles of a_j that are closest to p.
      3) For every cross combination (article_i, article_j) of the two sets,
         compute the cosine distance between their embeddings.
      4) Return the maximum of those distances.

    Args:
        article_p_id: ID of the reference article p.
        a_i: ID of the first author.
        a_j: ID of the second author.
        k: Number of closest articles taken per author (default 5).

    Returns:
        The largest cross-pair cosine distance as a Python float, or
        ``math.nan`` when either author has no ranked articles or no pair of
        embeddings could be compared.
    """
    top_i = get_top_k_closest_articles(article_p_id, a_i, k=k)  # [(art_id, dist_to_p), ...]
    top_j = get_top_k_closest_articles(article_p_id, a_j, k=k)

    # Nothing to compare if either author has no ranked articles.
    if not top_i or not top_j:
        return math.nan

    # Resolve embeddings once per article instead of once per cross pair;
    # IDs without an embedding are dropped.
    embs_i = [article_embeddings.get(art_id) for art_id, _ in top_i]
    embs_j = [article_embeddings.get(art_id) for art_id, _ in top_j]
    embs_i = [emb for emb in embs_i if emb is not None]
    embs_j = [emb for emb in embs_j if emb is not None]

    distances_ij = [cosine(emb_i, emb_j) for emb_i in embs_i for emb_j in embs_j]
    if not distances_ij:
        return math.nan

    return float(np.max(distances_ij))

In [61]:
# Collapse df_collab to one row per article that has an embedding:
#   - authors_of_p: list of the article's author_ids
#   - p_pub_dt:     publication date taken from the article's first row
# Rows are filtered on the grouping key before aggregating, which yields the
# same groups as filtering afterwards. Watch memory usage if the dataset is huge.
grouped = (
    df_collab
    .filter(pl.col("article_id").is_in(article_embeddings.keys()))
    .group_by("article_id", maintain_order=True)
    .agg(
        pl.col("author_id").alias("authors_of_p"),
        pl.col("article_publication_dt").first().alias("p_pub_dt"),
    )
)

# Resulting columns: [article_id, authors_of_p, p_pub_dt]

In [63]:
# Pair distances, bucketed by whether the pair was a new collaboration.
distances_new = []
distances_existing = []

for record in tqdm(grouped.iter_rows(named=True)):
    article_id = record["article_id"]

    # Every unordered author pair of the article: N authors -> N*(N-1)/2 pairs.
    for first_author, second_author in combinations(record["authors_of_p"], 2):
        pair_key = (article_id, frozenset([first_author, second_author]))

        # Pairs without a new/existing flag are ignored.
        if pair_key in is_new_map:
            distance = compute_pair_distance(article_id, first_author, second_author)

            # NaN means no distance could be computed for this pair.
            if not math.isnan(distance):
                bucket = distances_new if is_new_map[pair_key] else distances_existing
                bucket.append(distance)

41616it [10:02, 69.03it/s]  


KeyboardInterrupt: 

In [64]:
# Materialise both samples as float32 numpy arrays.
distances_new = np.array(distances_new, dtype=np.float32)
distances_existing = np.array(distances_existing, dtype=np.float32)

# The Wasserstein distance needs at least one observation in each sample.
if len(distances_new) > 0 and len(distances_existing) > 0:
    w_distance = wasserstein_distance(distances_new, distances_existing)
    print(f"Wasserstein-1 distance (new vs existing): {w_distance:.4f}")
else:
    print("Not enough data to compute Wasserstein distance.")

Wasserstein-1 distance (new vs existing): 0.0243


In [65]:

# Two-sample Kolmogorov-Smirnov test comparing the distance distributions of
# new and existing author pairs (ks_2samp is imported in the imports cell).
ks_stat, p_value = ks_2samp(distances_new, distances_existing)
print(f"KS test stat: {ks_stat}, p-value: {p_value}")

KS test stat: 0.2375562367275667, p-value: 0.0
