In [1]:
import pandas as pd
import numpy as np
import pickle as pkl
import chromadb
from chromadb.utils import embedding_functions
from chromadb.config import Settings


def load_pkl(path):
    with open(path, "rb") as f:
        return pkl.load(f)

In [2]:
# setup database
chroma_client = chromadb.PersistentClient(path="../data/vecdb")
# creating embedding function
sentence_transformer_ef = (
    embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="mixedbread-ai/mxbai-embed-large-v1"
    )
)

  from .autonotebook import tqdm as notebook_tqdm


In [28]:
table = chroma_client.get_collection('golden_vecs',embedding_function=sentence_transformer_ef)

In [4]:
# load data
zb_golden = pd.read_csv('../data/zbmath_golden_lookup.csv')

In [92]:
conts[conts["zbMATH_ID"] == "1363213"].Title.item()

'Indecomposable higher Chow cycles on products of elliptic curves'

In [13]:
collection.get("1002951", include=["embeddings"])

{'ids': ['1002951'],
 'embeddings': [[0.04902418330311775,
   -0.23463520407676697,
   0.1181865856051445,
   0.445631206035614,
   -0.42419227957725525,
   -0.576542854309082,
   -0.6553453803062439,
   0.10104072839021683,
   0.11176996678113937,
   0.6563395261764526,
   0.6411189436912537,
   0.3685266077518463,
   -0.249489426612854,
   -0.3435029089450836,
   0.11031285673379898,
   -0.0525948740541935,
   -0.3492167890071869,
   -0.5255294442176819,
   -1.538670301437378,
   -0.6092823147773743,
   0.19667401909828186,
   -0.05373707413673401,
   -1.0606199502944946,
   -1.3551607131958008,
   -0.28651732206344604,
   0.5230482816696167,
   0.44867968559265137,
   0.4122858941555023,
   1.2898894548416138,
   0.29341787099838257,
   0.765991747379303,
   -0.5104576945304871,
   -0.06031504273414612,
   -0.196732759475708,
   0.12677977979183197,
   0.3484562933444977,
   0.408886194229126,
   -0.19311600923538208,
   -0.4455784857273102,
   -0.7900330424308777,
   0.290980398654

In [14]:
conts[conts["zbMATH_ID"] == "1363213"].Title.item()

ValueError: can only convert an array of size 1 to a Python scalar

In [18]:
recs = load_pkl('../data/zbmath_golden_recs.pickle')

In [23]:
zb_golden = zb_golden.set_index('id')

In [39]:
list(zb_golden.loc[list(recs.keys())].text)

['Let \\(k\\) be a field, and let \\(E\\) be a free \\(k\\)-bimodule, or in other words, a two-sided \\(k\\)-vector space. The author introduces a concept of admissible two-sided \\(k\\)-vector spaces and defines for such a space a noncommutative analog of the symmetric algebra. Then he computes the skew fields of fractions for some classes of admissible two-sided vector spaces. The components of degree zero of these skew fields can be considered as function fields of noncommutative ruled surfaces (cf. [\\textit{M. Artin, M. Van den Bergh}, J. Algebra 133, No. 2, 249-271 (1990; Zbl 0717.14001)]). As an application a clear description of the birational equivalence classes of certain types of such surfaces is obtained.',
 "Let \\(X\\) be a smooth complex projective variety, and let \\(\\text{CH}^k(X,1)\\) be the higher Chow group. A higher Chow cycle is said to be decomposable if it lies in the image of the intersection product, i.e., the group of decomposable higher Chow cycles is \n\\[

In [42]:
preds = load_pkl('../data/predictions/preds.pickle')

In [None]:
def P_R_F1_at_k(preds,k):

    recalls = []
    precisions = []
    f1s = []
    for truth, predictions in preds:
        gold_recs = set(truth)
        if len(gold_recs) == 0:
            continue
        predicted = set(predictions[:k])

        hits = len(predicted.intersection(gold_recs))

        # not sure about this. If you hit 1 in k=3 and you have 10 gold recs you have recall .33
        # but if you hit 1 in k=10 with gold recs 10 you have recal .1
        # recall = hits / (len(gold_recs) if len(gold_recs) <= k else k)
        recall = hits / len(gold_recs)
        precision = hits / k
        f1 = (
            (2 * (precision * recall) / (precision + recall))
            if (precision + recall) > 0
            else 0
        )

        recalls.append(recall)
        precisions.append(precision)
        f1s.append(f1)

    return np.mean(precisions), np.mean(recalls), np.mean(f1s)

In [43]:
preds

[([4181495, 930151, 1579464, 5083606, 6338806],
  ['1566951',
   '6338806',
   '5083606',
   '917671',
   '5154884',
   '1448638',
   '6572608',
   '1104115',
   '1579464',
   '6095689',
   '1013949'])]

In [40]:
table.query(
    query_texts=list(zb_golden.loc[list(recs.keys())].text),
    n_results=11,
)["ids"]

[['1566951',
  '6338806',
  '5083606',
  '917671',
  '5154884',
  '1448638',
  '6572608',
  '1104115',
  '1579464',
  '6095689',
  '1013949'],
 ['1363213',
  '1445144',
  '1801581',
  '1036371',
  '438592',
  '949890',
  '1000561',
  '2051858',
  '2165994',
  '5535059',
  '5680965'],
 ['1308161',
  '5638157',
  '1356576',
  '1104115',
  '5007259',
  '4193896',
  '1239817',
  '1448638',
  '1138672',
  '3989438',
  '1443593'],
 ['1303018',
  '5354085',
  '427914',
  '951967',
  '5120555',
  '1443593',
  '1026489',
  '3989438',
  '6841544',
  '5199713',
  '2067093'],
 ['1591097',
  '1663809',
  '1621343',
  '1943338',
  '5049067',
  '1915690',
  '2136591',
  '3867686',
  '1758339',
  '1036371',
  '1445144'],
 ['2105416',
  '1059629',
  '6216299',
  '1544078',
  '6099398',
  '6265550',
  '6298842',
  '6224865',
  '1910606',
  '1837308',
  '5288649'],
 ['1013949',
  '6677815',
  '1864745',
  '5990842',
  '6029101',
  '1491052',
  '1104115',
  '1738275',
  '6264538',
  '6355533',
  '4134258'

'default_database'