In [23]:
import os
from tqdm import tqdm
from datasets import load_dataset
import sys
sys.path.append('../evaluation')
from evaluate import RetrievalSystem, main as evaluate_main
import yaml
import hyde
import json
from vector_store import EmbeddingClient, Document, DocumentLoader

  from .autonotebook import tqdm as notebook_tqdm
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [4]:
# NOTE(review): absolute, user-specific path, and `api_key` is not used by any
# visible cell (the retrievers below are given `config_path` instead) — consider
# reading the key from an environment variable if it is needed here at all.
with open("/users/christineye/retrieval/config.yaml", 'r') as config_file:
    api_key = yaml.safe_load(config_file)['openai_api_key']

In [5]:
# Auto-reload edited local modules (hyde, weighted, keywords, ...) so changes to
# them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [11]:
# HyDE retriever over the local vector store. Per the log below, construction
# loads the embeddings, documents, index mapping, document dates and metadata.
test_hyde = hyde.HydeRetrievalSystem(
    embeddings_path="/users/christineye/retrieval/data/vector_store/embeddings_matrix.npy",
    documents_path="/users/christineye/retrieval/data/vector_store/documents.pkl",
    index_mapping_path="/users/christineye/retrieval/data/vector_store/index_mapping.pkl",
    config_path="/users/christineye/retrieval/config.yaml",
    generate_n=3,
    embed_query=True,
    max_doclen=100,
)

Loading embeddings...
Loading documents...
Loading index mapping...
Processing document dates...
Loading metadata...
Data loaded successfully.


In [21]:
# Smoke-test the HyDE retriever on one astronomy question. `arxiv_id` is passed
# through to retrieve() with the query — presumably the paper the question is
# asked from (TODO confirm against hyde.py).
query, arxiv_id, top_k = "What is the stellar mass of the Milky Way?", "2301.00001", 10
results = test_hyde.retrieve(query, arxiv_id, top_k)

In [9]:
import weighted
# Weighted retriever from the local `weighted` module; per the log below, the
# constructor loads embeddings, documents, the index mapping and an existing index.
test_weighted = weighted.WeightedRetrievalSystem()

Loading embeddings...
Loading documents...
Loading index mapping...
Processing document dates...
Data loaded successfully.
Loading existing index...
Index loaded successfully.


In [10]:
# Build both sub-retrievers of the weighted system. The log below suggests
# make_bow loads an existing index and make_embed reloads the embedding store
# (TODO confirm in weighted.py).
test_weighted.make_bow()
test_weighted.make_embed()

Loading existing index...
Index loaded successfully.
Loading embeddings...
Loading documents...
Loading index mapping...
Processing document dates...
Data loaded successfully.


In [18]:
# Shared evaluation inputs reused by the retriever cells below.
query, arxiv_id, top_k = (
    "What is the stellar mass of the Milky Way?",
    "2301.00001",
    10,
)

In [11]:
# Full astro paper corpus from the Hugging Face Hub (six ~235 MB shards,
# 271,544 training examples per the log below).
astro_meta = load_dataset(
    path="JSALT2024-Astro-LLMs/astro_paper_corpus",
    split="train",
)

Downloading data: 100%|██████████| 238M/238M [00:09<00:00, 23.8MB/s]
Downloading data: 100%|██████████| 237M/237M [00:07<00:00, 29.8MB/s]
Downloading data: 100%|██████████| 240M/240M [00:08<00:00, 29.7MB/s]
Downloading data: 100%|██████████| 235M/235M [00:07<00:00, 32.5MB/s]
Downloading data: 100%|██████████| 233M/233M [00:07<00:00, 32.6MB/s]
Downloading data: 100%|██████████| 237M/237M [00:07<00:00, 30.7MB/s]
Generating train split: 100%|██████████| 271544/271544 [00:05<00:00, 54228.25 examples/s]


In [23]:
# Spot-check the BOW text normalizer on one keyword of the first corpus paper.
test_weighted.bow.preprocess_text(astro_meta[0]['keyword'][5])

'fornax dark matter content'

In [36]:
# Metadata fields to keep per paper: every corpus column except the long
# full-text sections. Derived from the dataset itself — the previous version
# read a `paper` variable that only exists after the metadata loop in a later
# cell has run, so it raised NameError on a fresh kernel. Column order is kept,
# and a missing excluded column no longer raises (list.remove did).
keys = [
    column for column in astro_meta.column_names
    if column not in {'abstract', 'introduction', 'conclusions'}
]

In [42]:
# Keep only the lightweight fields listed in `keys` for each paper, indexed by
# arXiv id (a repeated id keeps the last row seen). tqdm gives progress over the
# 271,544-row corpus, and writing into `metadata` directly avoids leaking an
# `id_str` temporary into the notebook namespace.
metadata = {}
for paper in tqdm(astro_meta):
    metadata[paper['arxiv_id']] = {key: paper[key] for key in keys}

In [27]:
# Import the NLTK corpus reader under its own name so that `stopwords` is
# unambiguously the set of English stopwords. Previously the module name itself
# was rebound to a set, so the same name meant two different things.
from nltk.corpus import stopwords as nltk_stopwords
import nltk
nltk.download('stopwords', quiet=True)
stopwords = set(nltk_stopwords.words('english'))

In [21]:
# Load the prebuilt metadata from disk. NOTE: this replaces the `metadata` dict
# built from `astro_meta` in the earlier cell; make_keyword_index reads the
# 'keyword_search' and 'arxiv_id' fields from this version.
# JSON is UTF-8 by spec, so decode explicitly rather than relying on the
# platform's locale-dependent default encoding.
with open('../data/vector_store/metadata.json', 'r', encoding='utf-8') as f:
    metadata = json.load(f)

In [28]:
def make_keyword_index(metadata, stop_words=None, show_progress=True):
    """Build an inverted index from normalized keyword phrases to arXiv ids.

    Each keyword in ``paper['keyword_search']`` is lower-cased, split on
    whitespace, stripped of stopwords and re-joined with single spaces. The
    resulting phrase maps to the ``paper['arxiv_id']`` values that listed it,
    in the iteration order of ``metadata``.

    Args:
        metadata: mapping of id -> paper dict with at least ``'arxiv_id'`` and
            ``'keyword_search'`` (an iterable of keyword strings).
        stop_words: words to drop from each keyword. Defaults to the notebook's
            global ``stopwords`` set (NLTK English stopwords).
        show_progress: wrap the iteration in a ``tqdm`` progress bar.

    Returns:
        dict[str, list[str]]: phrase -> arXiv ids. A paper appears at most once
        per phrase. Keywords made up entirely of stopwords are skipped;
        previously they were all filed under the empty string ``''``.
    """
    if stop_words is None:
        stop_words = stopwords
    papers = metadata.values()
    if show_progress:
        # Pass `total` so the bar shows a percentage/ETA instead of a bare count.
        papers = tqdm(papers, total=len(metadata))

    keyword_index = {}
    for paper in papers:
        arxiv_id = paper['arxiv_id']
        # dict.fromkeys de-duplicates while preserving first-seen order, so e.g.
        # "Dark Matter" and "dark matter" on one paper yield a single posting.
        terms = dict.fromkeys(
            ' '.join(word for word in keyword.lower().split() if word not in stop_words)
            for keyword in paper['keyword_search']
        )
        for term in terms:
            if term:
                keyword_index.setdefault(term, []).append(arxiv_id)
    return keyword_index

In [29]:
# Build the phrase -> arXiv-id inverted index from the metadata loaded from disk.
keyword_index = make_keyword_index(metadata)

271540it [00:11, 24324.32it/s]


In [30]:
# Persist the inverted index next to the other vector-store artifacts.
with open('../data/vector_store/keyword_index.json', mode='w') as out_file:
    out_file.write(json.dumps(keyword_index))

In [14]:
import keywords

In [21]:
# Keyword-index retriever from the local `keywords` module; per the log below it
# loads an existing index at construction.
keyword = keywords.KeywordRetrievalSystem()

Loading existing index...
Index loaded successfully.


In [22]:
# NOTE(review): with top_k=10 the (truncated) output below lists dozens of ids,
# including repeats ('1607.00374', '1607.02606', '1403.6111') and non-arXiv-style
# ids such as 'astro-ph9804053_arXiv.txt' — confirm KeywordRetrievalSystem.retrieve
# honours top_k and de-duplicates its results.
keyword.retrieve(query, arxiv_id, top_k)

['1607.08157',
 '1607.05040',
 '1607.05324',
 '1607.06479',
 '1607.02546',
 '1607.08765',
 '1607.01483',
 '1607.03625',
 '1607.08672',
 '1607.04418',
 '1607.01049',
 '1607.06325',
 '1607.06462',
 '1607.08761',
 '1607.03506',
 '1607.05283',
 '1607.07328',
 '1607.08755',
 '1607.01023',
 '1607.01798',
 '1607.03497',
 '1607.01806',
 '1607.00374',
 '1607.00374',
 '1607.01009',
 '1607.04275',
 '1607.01468',
 '1607.02606',
 '1607.02606',
 '1607.04278',
 '1403.4638',
 '1403.6720',
 '1403.0324',
 '1403.0954',
 '1403.0427',
 '1403.4606',
 '1403.3539',
 '1403.6827',
 '1403.5717',
 '1403.1280',
 '1403.1212',
 '1403.2475',
 '1403.1561',
 '1403.5963',
 '1403.5053',
 '1403.2389',
 '1403.4352',
 '1403.3401',
 '1403.4960',
 '1403.2733',
 '1403.4215',
 '1403.5611',
 '1403.6018',
 '1403.3121',
 '1403.6111',
 '1403.6111',
 '1403.6621',
 '1403.0576',
 'astro-ph9804053_arXiv.txt',
 'astro-ph9804011_arXiv.txt',
 'astro-ph9804123_arXiv.txt',
 'hep-ph9804285_arXiv.txt',
 'astro-ph9804026_arXiv.txt',
 'astro-ph

In [47]:
# HyDE retrieval with Cohere reranking. `hyde_reranking` is not imported by any
# earlier cell, so this line raised NameError on a fresh kernel; import it here.
import hyde_reranking

test = hyde_reranking.HydeCohereRetrievalSystem(config_path = "../config.yaml")

Loading embeddings...
Loading documents...
Loading index mapping...
Processing document dates...
Loading metadata...
Data loaded successfully.


In [57]:
# Baseline: HyDE + Cohere reranking with citation weighting disabled.
# NOTE(review): this cell ran as In[57], after In[56] below rebound `query` to the
# cosmological-simulations question, so its output is presumably for that query,
# not the Milky Way query defined earlier. A top-to-bottom re-run will differ;
# consider passing the query explicitly.
test.weight_citation = False
test.retrieve(query, arxiv_id, top_k)

['0801.1023',
 '1711.01453',
 '2101.05821',
 '1001.3411',
 '1503.06065',
 '1703.08585',
 '1611.01545',
 '1510.06665',
 '2003.04925',
 '1908.00116']

In [56]:
# Same retriever with citation weighting enabled, on a broader question.
# NOTE(review): this rebinds the global `query`, which the cell above (executed
# later, as In[57]) also picked up.
query = "What are the primary computational methods used in modern cosmological simulations, and what are some notable examples of each approach?"
test.weight_citation = True
test.retrieve(query, arxiv_id, top_k)

['0801.1023',
 '1711.01453',
 '2101.05821',
 '1001.3411',
 '1703.08585',
 '1510.06665',
 '2003.04925',
 '1908.00116',
 'astro-ph0611863_arXiv.txt',
 'astro-ph0005502_arXiv.txt']

In [55]:
import spacy
from collections import Counter
from string import punctuation
from nltk.corpus import stopwords
import nltk

nltk.download('stopwords', quiet=True)
# Rebinds `stopwords` from the nltk.corpus reader to a plain set of English stopwords.
stopwords = set(stopwords.words('english'))

# Load the small English spaCy pipeline, downloading it only if it is missing.
# Previously spacy.load ran *before* spacy.cli.download, so on a fresh
# environment the cell failed before the download could happen, and on an
# existing one the download was re-invoked on every run.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    spacy.cli.download('en_core_web_sm')
    nlp = spacy.load("en_core_web_sm")

# NOTE(review): the "textrank" component factory is registered by the pytextrank
# package, which no visible cell imports — confirm it is imported/registered
# before this line on a fresh kernel.
nlp.add_pipe("textrank")

def get_keywords(text):
    """Return tokens of `text` tagged as proper noun, adjective or noun.

    The text is lower-cased before being run through the global `nlp` pipeline,
    so returned tokens are lower-case. Tokens in spaCy's default stop-word list,
    or that occur as a substring of `string.punctuation`, are dropped. Order and
    duplicates follow the token order.

    NOTE(review): lower-casing before tagging likely weakens PROPN detection —
    confirm that is intended.
    """
    wanted_pos = ('PROPN', 'ADJ', 'NOUN')
    return [
        tok.text
        for tok in nlp(text.lower())
        if tok.text not in nlp.Defaults.stop_words
        and tok.text not in punctuation
        and tok.pos_ in wanted_pos
    ]

def parse_doc(text, nret = 10):
    """Return up to `nret` top-ranked TextRank phrases of `text`, lower-cased.

    Whitespace-separated words whose lower-case form is in the global NLTK
    `stopwords` set are removed first. The remaining words are re-joined with
    single spaces and run through the global `nlp` pipeline, whose `_.phrases`
    attribute comes from the "textrank" component.
    """
    kept_words = [w for w in text.split() if w.lower() not in stopwords]
    ranked_phrases = nlp(' '.join(kept_words))._.phrases
    return [phrase.text.lower() for phrase in ranked_phrases[:nret]]

In [68]:
# Sanity check on noisy input: as the output shows, punctuation ('..') survives
# inside the returned phrase, since parse_doc does not strip it.
parse_doc('.. asldkjf .asd.')

['.. asldkjf']