In [None]:
import re, pickle
from itertools import chain

import psycopg2
from datasets import load_dataset

from utils import strip_punc

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Stopword list: one word per line, with the line break removed from each entry.
with open('stopwords-en.txt', mode='r') as stopword_file:
    stopwords = {entry.strip('\n') for entry in stopword_file.readlines()}

# NOTE: pickle.load can execute arbitrary code embedded in the file, so
# total_counts.pkl must come from a trusted source.
with open('total_counts.pkl', mode='rb') as counts_file:
    total_counts = pickle.load(counts_file)

# Every key of total_counts that is not a stopword, in the mapping's own order.
vocabulary = [
    word
    for word in total_counts.keys()
    if word not in stopwords
]

In [None]:
# Instruction prefix asking for word variations, one comma-separated group per
# line. The string ends with "Text:" and a newline, so the text to expand is
# presumably appended directly after it -- confirm against the caller.
variations_prompt = """Generate reasonable variations for each word in the following text. For example, for \
the word philosophy, you might respond philosophies, philosopher, philosophizing. Your only \
output should be the related words: write each group as a comma separated list and use line \
breaks to separate groups.

Text:
"""

In [None]:
def parse_variations(text):
    """Flatten line-separated, comma-separated word groups into one word list.

    Args:
        text: one group of words per line, with the words in a group
            separated by commas (spacing around the commas does not matter).

    Returns:
        A flat list of the words in their original order. The characters
        ``'``, ``"`` and ``|`` are removed from every word (so ``potter's``
        becomes ``potters``), surrounding whitespace is trimmed, and blank
        lines or empty items contribute nothing.
    """
    words = []
    for line in text.split('\n'):
        for item in line.split(','):
            # Same character class as before: it drops quotes and, because
            # "|" is literal inside [...], pipes as well.
            word = re.sub('[\'|\"]', '', item).strip()
            # Skip empties (trailing newline, blank lines, doubled commas) so
            # an empty string is never passed on as a query word.
            if word:
                words.append(word)
    return words

In [None]:
def query_db(queries):
    """Find the wiki chunks indexed under any of the given words.

    Each word is hashed with the built-in ``hash`` and matched against
    ``keywords.word_hash``. The chunk ids stored in the first column
    (``chunks``) of the matching rows are then looked up in ``wikichunks``.

    Args:
        queries: iterable of words to look up.

    Returns:
        The rows fetched from ``wikichunks`` as ``(article_id, chunk_index)``
        pairs, or an empty list when ``queries`` is empty or no keyword row
        yields any chunk id.
    """
    # NOTE(review): hash() of a str is salted per interpreter process unless
    # PYTHONHASHSEED is fixed, so these values only match the stored
    # word_hash column if it was built under the same seed -- confirm.
    hashes = ', '.join([str(hash(word)) for word in queries])
    if not hashes:
        # "IN ()" is a syntax error in PostgreSQL, and there is nothing to
        # look up anyway, so return before opening a connection.
        return []

    conn = psycopg2.connect("dbname=postgres user=postgres")
    try:
        # The cursor context manager closes the cursor even if a query fails.
        with conn.cursor() as cur:
            # Only integers produced by hash() are interpolated here.
            keywords_query = f'''
                SELECT chunks
                FROM keywords
                WHERE word_hash in ({hashes})
            '''
            cur.execute(keywords_query)
            keyword_rows = cur.fetchall()

            # Flatten the per-keyword id collections. str() keeps the join
            # working whether the ids arrive as strings or as integers, and
            # "or ()" tolerates a NULL chunks value.
            chunk_ids = ', '.join(
                str(chunk_id)
                for row in keyword_rows
                for chunk_id in (row[0] or ())
            )
            if not chunk_ids:
                return []

            chunks_query = f'''
                SELECT article_id, chunk_index
                FROM wikichunks
                WHERE id in ({chunk_ids})
            '''
            cur.execute(chunks_query)
            return cur.fetchall()
    finally:
        # Always release the connection, including on query errors.
        conn.close()

In [None]:
# 'train' split of the "wikimedia/wikipedia" dataset, config "20231101.en".
# retrieve() below indexes into this by row position.
ds = load_dataset("wikimedia/wikipedia", "20231101.en")['train']
def retrieve(article_id, chunk_id):
    """Return one newline-delimited piece of an article's text.

    The 'text' field of row ``article_id`` in ``ds`` is passed through
    ``strip_punc``, the result is split on line breaks, and the piece at
    position ``chunk_id`` is returned.
    """
    article_text = strip_punc(ds[article_id]['text'])
    pieces = article_text.split('\n')
    return pieces[chunk_id]

In [None]:
# Show the prompt as it will be sent; print() renders the real line breaks
# instead of the escaped repr a bare expression would display.
print(variations_prompt)

In [None]:
# Word-variation groups in the layout variations_prompt asks for: one
# comma-separated group per line. Presumably a model response pasted in by
# hand -- confirm. Some lines carry trailing spaces, which parse_variations
# strips.
variations = '''compass, compasses, compass-like
potter, potters, pottery, potting, potter's
wheel, wheels, wheeled, wheelwright  
saw, saws, sawing, sawed
invention, inventions, inventive, inventing, inventor, invented
attribute, attributes, attributed, attributing
Athenian, Athenians  
youth, youths, youthful, youthfulness
'''

In [None]:
# Look up the chunks indexed under any of the parsed word variations.
results = query_db(parse_variations(variations))

# query_db selects (article_id, chunk_index) per row; fetch the text of each.
rag_text = [
    retrieve(article_id, chunk_index)
    for article_id, chunk_index in results
]