In [177]:
import io
import sqltables.sqlite3
import numpy as np
import json
import sklearn.decomposition
import networkx as nx

In [178]:
db = sqltables.sqlite3.Database(uri=True)
db.execute("attach 'file:submissions.sqlite3?mode=ro' as submissions")

In [179]:
metadata = {
    row.arxiv_id: row for row in db.open_table("submissions")
}

In [180]:
paper_embeddings = db.open_table("specter").view("""
select arxiv_id, min(paper_info) as paper_info from _ group by arxiv_id
""")

In [181]:
[[N]] = paper_embeddings.view("select count(*) from _")
proto_row = next(iter(paper_embeddings))
proto_data = json.loads(proto_row.paper_info)
arxiv_ids = [None] * N
embeddings = np.zeros((N, len(proto_data["embedding"]["vector"])))
N

16923

In [182]:
for i, row in enumerate(paper_embeddings):
    arxiv_ids[i] = row.arxiv_id
    data = json.loads(row.paper_info)
    embedding = np.array(data["embedding"]["vector"])
    embeddings[i, :] = embedding

In [183]:
embeddings

array([[-0.66105336,  0.34472892, -0.93162173, ...,  0.12510972,
         0.5625037 ,  0.4185591 ],
       [ 0.61239731,  0.18827562,  0.19636126, ...,  0.43264362,
         0.11558405, -0.07368433],
       [-0.03837255,  0.00634928,  0.78059745, ...,  0.46514544,
        -0.27353293,  0.42210951],
       ...,
       [-0.13790952,  0.07871211, -0.29383036, ..., -0.98424453,
        -1.48461008, -0.157213  ],
       [-0.53119463,  0.26525766,  0.03041879, ..., -0.32618117,
        -0.70937699, -0.29410005],
       [-0.24245238,  1.2394197 , -0.1845255 , ..., -0.92205119,
        -0.01297263,  1.35307848]])

In [184]:
pca = sklearn.decomposition.PCA(32)
reduced_embeddings = pca.fit_transform(embeddings)

In [185]:
reduced_embeddings.nbytes / 2**10

4230.75

In [186]:
%%time
k = 4
nearest_neighbours = {}
for i in range(reduced_embeddings.shape[0]):
    emb = reduced_embeddings[i, :]
    d2 = np.sum((emb[None, :] - reduced_embeddings)**2, axis=-1)
    nearest_neighbours[i] = np.argsort(d2)[1:(k+1)].tolist()

CPU times: user 29.3 s, sys: 2.78 s, total: 32.1 s
Wall time: 32.1 s


In [187]:
nearest_neighbours

{0: [16810, 8945, 15625, 10195],
 1: [11629, 6540, 10645, 9792],
 2: [10227, 8956, 11629, 12613],
 3: [6308, 2823, 1701, 9349],
 4: [16205, 5445, 1772, 14608],
 5: [13644, 9442, 2164, 2030],
 6: [6264, 12069, 3881, 7108],
 7: [8400, 3433, 460, 10691],
 8: [2296, 5977, 5832, 12807],
 9: [7730, 15412, 7873, 2656],
 10: [15347, 962, 10014, 4214],
 11: [15184, 16014, 6712, 5946],
 12: [2320, 6845, 8848, 105],
 13: [2387, 9095, 9182, 6157],
 14: [1737, 10471, 322, 5195],
 15: [10907, 10235, 11759, 12960],
 16: [10544, 7335, 6789, 16603],
 17: [11882, 14052, 12078, 6826],
 18: [8510, 10186, 7784, 14173],
 19: [8797, 14963, 7164, 8780],
 20: [586, 13795, 5909, 12486],
 21: [5057, 5749, 8878, 10797],
 22: [2297, 5486, 13005, 2331],
 23: [6401, 9057, 16638, 102],
 24: [1905, 1550, 15274, 1317],
 25: [9421, 3765, 328, 4477],
 26: [11423, 6138, 2949, 16446],
 27: [12340, 13837, 15940, 12018],
 28: [13120, 11110, 14706, 9699],
 29: [913, 2153, 1510, 14234],
 30: [2894, 6173, 5028, 3422],
 31: [528

In [188]:
nn_graph = nx.from_dict_of_lists(nearest_neighbours)

In [189]:
for u, v in nn_graph.edges:
    dist = np.linalg.norm(embeddings[u] - embeddings[v])
    nn_graph.edges[(u, v)]["weight"] = np.exp(-dist**2)

In [190]:
mst = nx.minimum_spanning_tree(nn_graph, weight="weight")
len(list(nx.connected_components(mst)))

1

In [191]:
root = next(n for n in mst.nodes if len(mst[n]) == 1)
root

6

In [192]:
nodes = list(nx.depth_first_search.dfs_preorder_nodes(mst, root))
nodes

[6,
 3881,
 7108,
 705,
 10412,
 2995,
 7805,
 10307,
 13860,
 12546,
 2564,
 1222,
 4988,
 1671,
 897,
 619,
 15063,
 1469,
 2854,
 13769,
 9211,
 13684,
 13821,
 10647,
 13671,
 10945,
 9213,
 11788,
 6120,
 11742,
 14975,
 12559,
 13206,
 9871,
 16079,
 3241,
 10805,
 14241,
 9652,
 3903,
 9631,
 806,
 16728,
 2680,
 1923,
 6061,
 6869,
 5710,
 5216,
 6180,
 11460,
 6994,
 8390,
 6011,
 901,
 16457,
 12856,
 6590,
 9264,
 6296,
 11599,
 6315,
 12626,
 5505,
 11483,
 4104,
 1617,
 3715,
 11555,
 3102,
 8708,
 15728,
 9099,
 4578,
 6697,
 12323,
 14950,
 5944,
 13946,
 16186,
 4267,
 4609,
 16587,
 6122,
 2837,
 14349,
 1168,
 10004,
 9890,
 1771,
 3580,
 1418,
 6169,
 14037,
 8099,
 11457,
 12026,
 4182,
 7432,
 1867,
 5372,
 4844,
 3859,
 10067,
 2142,
 10149,
 6341,
 10191,
 15711,
 14539,
 12948,
 9659,
 10687,
 12859,
 5407,
 15597,
 150,
 7602,
 11173,
 3358,
 7736,
 1528,
 4917,
 10679,
 11390,
 771,
 3751,
 14332,
 12506,
 12820,
 12361,
 13851,
 8264,
 750,
 12974,
 5091,
 17

In [193]:
def index_to_title(index):
    return metadata[arxiv_ids[index]].title

In [194]:
[index_to_title(n) for n in nodes[:20]]

['Deep Reinforcement Learning for Field Development Optimization',
 'Development and Validation of an AI-Driven Model for the La Rance Tidal Barrage: A Generalisable Case Study',
 'Deep reinforcement learning for optimal well control in subsurface systems with uncertain geology',
 'Reinforcement Learning-Based Automatic Berthing System',
 'Sim2real for Reinforcement Learning Driven Next Generation Networks',
 'Hyperparameter Tuning for Deep Reinforcement Learning Applications',
 'Self-scalable Tanh (Stan): Faster Convergence and Better Generalization in Physics-informed Neural Networks',
 'PyTSK: A Python Toolbox for TSK Fuzzy Systems',
 'Neural Networks for Scalar Input and Functional Output',
 'Single Model Uncertainty Estimation via Stochastic Data Centering',
 'Kernel Methods and Multi-layer Perceptrons Learn Linear Models in High Dimensions',
 'Acceleration techniques for optimization over trained neural network ensembles',
 'DC and SA: Robust and Efficient Hyperparameter Optimiza

In [195]:
differences = np.zeros([len(nodes)])
for i, node_id in enumerate(nodes):
    if i == 0:
        continue
    prev_node_id = nodes[i-1]
    differences[i] = np.linalg.norm(embeddings[node_id] - embeddings[prev_node_id])

In [196]:
np.min(differences[1:])

3.915607913611192

In [197]:
if "paper_ordering" in db.tables:
    db.drop_table("paper_ordering")
paper_ordering = db.create_table(name="paper_ordering", column_names=["arxiv_id", "position", "difference"])
paper_ordering.insert([(arxiv_ids[node_id], i, differences[i]) for i, node_id in enumerate(nodes)])
paper_ordering

|arxiv\_id|position|difference|
|-|-|-|
|\'2008\.12627\'|0|0\.0|
|\'2202\.05347\'|1|14\.24064240877115|
|\'2203\.13375\'|2|13\.015359264473641|
|\'2112\.01879\'|3|16\.09389773454536|
|\'2206\.03846\'|4|15\.962007706115013|
|\'2201\.11182\'|5|11\.588314536070962|
|\'2204\.12589\'|6|14\.239815411587543|
|\'2206\.03310\'|7|13\.20013839368352|
|\'2208\.05776\'|8|13\.602995439631188|
|\'2207\.07235\'|9|12\.40296663183529|
|\'2201\.08082\'|10|11\.297199589351195|
|\'2112\.07007\'|11|14\.360290470467469|
|\'2202\.11841\'|12|11\.63804186120914|
|\'2112\.12589\'|13|18\.516266243250346|
|\'2112\.04682\'|14|16\.339446227742688|
|\'2112\.00579\'|15|14\.383876779052635|
|...|...|...|


In [198]:
[differences] = zip(*db.query("select difference from paper_ordering"))

In [199]:
diff_quantiles = np.quantile(differences, [0.75, 0.9])
def diff_color(diff):
    i = np.searchsorted(diff_quantiles, diff)
    p = 1 - 0.5*(i / len(diff_quantiles))
    return f"rgb({p*100}%, {p*100}%, {p*100}%)"

In [200]:
diff_color(10)

'rgb(100.0%, 100.0%, 100.0%)'

In [201]:
def quote_html(text):
    return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")

In [202]:
buf = io.StringIO()
buf.write("""
<!doctype html>
<style>
body { margin: 20px; }
</style>
<script>
function toggle(arxiv) {
  let elt = document.getElementById(arxiv);
  console.log(elt, elt.style.display);
  if(elt.style.display == "block") {
    elt.style.display = "none";
  } else {
    elt.style.display = "block";
  }
}
</script>
""")
rows = db.query("""
select * from paper_ordering join submissions using (arxiv_id) order by position limit 128
""")
for row in rows:
    arxiv = quote_html(row.arxiv_id)
    title = quote_html(row.title)
    authors = quote_html(row.authors)
    body_q = quote_html(row.abstract)
    url = quote_html(row.url)
    diff = row.difference
    color = diff_color(diff)
    buf.write(f"<div>")
    buf.write(f"<div style='border-top: 2px solid {color}' onclick='toggle(\"{arxiv}\")'>{title}</div>\n")
    buf.write(f"<div id='{arxiv}' style='display: none; margin-left: 20px'>")
    buf.write(f"<div>Authors: {authors}</div>")
    buf.write(f"<div style='padding-top: 10px; width: 80ex'>{body_q}</div><div><a href='{url}'>{arxiv}</a></div>\n")
    buf.write("</div>")
    buf.write("</div>")

with open("arxiv.html", "w+") as f:
    f.write(buf.getvalue())