In [1]:
# Input CSV locations; both files are read with pd.read_csv in the loading cell.
paper_data_file = "big_paper_data.csv"  # columns: paperId, title, embedding
citations_file = "big_citations.csv"  # columns: source, target

In [25]:
import ast
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [3]:
# Load both input tables up front, then echo each schema in load order
# (paper table first, citation edge list second).
paper_data = pd.read_csv(paper_data_file)
citations = pd.read_csv(citations_file)
for loaded_frame in (paper_data, citations):
    print(loaded_frame.columns)

Index(['paperId', 'title', 'embedding'], dtype='object')
Index(['source', 'target'], dtype='object')


In [4]:
def get_embedding_matrix(series):
    """Parse a Series of stringified embedding vectors into a 2-D array.

    Each element of ``series`` is expected to be the text form of a Python
    list literal (e.g. ``"[0.1, 0.2]"``). The text is parsed with
    ``ast.literal_eval`` rather than ``eval``: the strings come from a CSV
    file, and ``literal_eval`` can only build literals, never execute code.

    Parameters
    ----------
    series : pandas.Series
        One string per row; every row must parse to a vector of equal length.

    Returns
    -------
    numpy.ndarray
        Array with one row per element of ``series``.

    Raises
    ------
    ValueError, SyntaxError
        If an element is not a valid Python literal. ``np.vstack`` also
        raises ``ValueError`` for an empty ``series`` or rows of unequal
        length.
    """
    return np.vstack([np.array(ast.literal_eval(text)) for text in series])

# Decode the stringified "embedding" column once into a matrix with one row
# per paper; the trailing expression displays its shape.
all_embeddings = get_embedding_matrix(paper_data["embedding"])
all_embeddings.shape

(1524, 768)

In [5]:
# Assemble the dataset: one entry per citing ("source") paper, holding the
# embeddings of the papers it cites ("relevant") and the embeddings of every
# other paper in paper_data ("irrelevant").
# NOTE(review): a source paper that is itself a row of paper_data is not in its
# own citation list, so it lands among its own "irrelevant" embeddings —
# confirm this is intended.
paper_ids = paper_data["paperId"]
dataset = []
# groupby visits each source once, in sorted key order, and hands over its rows
# directly, instead of re-scanning the whole citations table per source paper.
for source_paper, source_citations in citations.groupby("source"):
    cited_ids = source_citations["target"].to_numpy()
    # Vectorized membership test. The mask follows paper_data's row order,
    # which is also the row order of all_embeddings.
    relevant_mask = paper_ids.isin(cited_ids).to_numpy()
    dataset.append({
        "source_paper": source_paper,
        "relevant_embeddings": all_embeddings[relevant_mask],
        "irrelevant_embeddings": all_embeddings[~relevant_mask],
    })

In [22]:
# More independent irrelevant embeddings: for every entry, draw (with
# replacement) one irrelevant embedding per *other* source paper in the dataset.
SAMPLING_SEED = 0
# Seeded generator so the sampled negatives are identical on every run.
rng = np.random.default_rng(SAMPLING_SEED)
for d in dataset:
    source = d["source_paper"]
    # Number of draws = number of entries belonging to a different source paper.
    n_samples = sum(1 for f in dataset if f["source_paper"] != source)
    # np.asarray keeps this cell runnable even if the embeddings were already
    # converted to nested lists by the JSON-preparation cell.
    irrelevant_pool = np.asarray(d["irrelevant_embeddings"])
    sample_idx = rng.integers(irrelevant_pool.shape[0], size=n_samples)
    # Index-array selection returns a copy with one row per drawn index, and
    # (unlike np.vstack on an empty list) also works when n_samples == 0.
    d["sampled_irrelevant"] = irrelevant_pool[sample_idx]

In [30]:
# Convert every embedding array into nested lists of native Python numbers so
# the dataset can be handed to json.dump.
for entry in dataset:
    for array_name in ["relevant_embeddings", "irrelevant_embeddings", "sampled_irrelevant"]:
        # ndarray.tolist() unwraps NumPy scalars into Python numbers, whereas
        # list(row) keeps them and json cannot encode e.g. float32 or integer
        # NumPy scalars. np.asarray makes a re-run on already-converted lists
        # harmless.
        entry[array_name] = np.asarray(entry[array_name]).tolist()

In [31]:
# Persist the assembled dataset as JSON. The output directory is created first
# so the cell also works when "data/" does not exist yet.
output_path = Path("data/dataset.json")
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("w") as f:
    json.dump(dataset, f)