In [1]:
import os

## For vector database
import lancedb

## For data handling
import polars as pl

## For embedding model
from langchain_community.embeddings import HuggingFaceEmbeddings

## Which experiment configuration the queries are run against.
CONFIG = "00"

## The database directory and the table name both derive from the config id.
DB_DIR = f"../db_data/db_{CONFIG}"
TABLE = f"table_{CONFIG}"

## Embedding model; it has to match the configuration chosen above:
MODEL = HuggingFaceEmbeddings(model_name="thenlper/gte-base") ## 00, 02
##MODEL = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2") ## 01, 03

## ANN query parameters, see:
## https://lancedb.github.io/lancedb/concepts/index_ivfpq/#query-the-index
LIMIT = 20
NPROBES = 20
REFINE_FACTOR = 10

## Connect to the database and open the table that was created earlier.
db = lancedb.connect(DB_DIR)
table = db.open_table(TABLE)

## Destination file for the raw query results of this configuration.
SAVE_PATH = f"../data_labeling/raw_results/query_results_config_{CONFIG}.parquet"


In [2]:
def ask_a_query(query_text, print_results=False):
    """Embed a query and run an ANN similarity search against ``table``.

    Parameters
    ----------
    query_text : str
        Natural-language query; it is embedded with ``MODEL.embed_query``.
    print_results : bool, default False
        When True, additionally print a preview of a few columns of the hits.

    Returns
    -------
    polars.DataFrame
        The search result converted with ``to_polars()``; the search is
        configured with ``LIMIT``, ``NPROBES`` and ``REFINE_FACTOR``.
    """
    ## Embed the query
    query_vector = MODEL.embed_query(query_text)

    ## Perform similarity search on database using ANN
    result = (
        table.search(query_vector)
        .limit(LIMIT)
        .nprobes(NPROBES)
        .refine_factor(REFINE_FACTOR)
        .to_polars()
    )

    if print_results:
        ## Widen the polars display only for this print. The context manager
        ## restores whatever settings were active before, even if printing
        ## fails, instead of overwriting them with hard-coded values.
        with pl.Config(tbl_rows=25, fmt_str_lengths=3000):
            print(result[["aware_post_type","reddit_subreddit", "reddit_text", "text_chunk", "_distance"]])

    return result

In [3]:
## Fixed set of benchmark questions that is sent to the vector database.
## The order matters: results are generated query by query in list order.
standard_queries = [
    ## General Motors
    "How do General Motors employees feel about RTO?",
    "What kind of benefits does GM offer?",
    "When should you apply for a promotion at GM?",
    ## UPS
    "How much does a driver make with UPS?",
    "How long is a typical UPS shift? OR Should I work a double shift at UPS?",
    "How do UPS employees feel about route cuts?",
    ## FedEx
    "Is it better to work at fedex express or fedex ground?",
    "How do FedEx employees feel about route cuts?",
    ## Lowes
    "How often do you get a raise at Lowes?",
    "Does your schedule get changed often at Lowes?",
    ## Starbucks
    "What is the worst drink to make for Starbucks baristas?",
    "Does Starbucks pay overtime?",
    "What is your favorite thing about working for Starbucks?",
    ## Whole Foods
    "How do Whole Foods workers feel about store managers?",
    "What job perks for Whole Foods employees value most?",
    ## Kraken
    "Do Kraken employees see themselves staying at the company for the long term?",
    "What do Kraken employees find frustrating in their day to day work?",
    ## Chase and banks in general
    "What benefits do Chase employees value most?",
    "Do Chase employees see opportunities for promotion and professional growth at the company?",
    "What causes bank employees the most stress at work?",
    "What are some reasons that bank employees quit their jobs?",
    ## Fidelity
    "Do Fidelity employees want to work remotely?",
    ## GameStop
    "Do GameStop employees feel valued by the company?",
    "What does a typical day look like when working for GameStop?",
    ## CVS
    "Do CVS employees feel safe at work?",
    "What do CVS workers do if they notice theft?",
]

def generate_query_results_for_labeling(standard_queries=standard_queries):
    """Run every query and collect the hits in one frame ready for labeling.

    Parameters
    ----------
    standard_queries : sequence of str
        Queries to run through ``ask_a_query``. Defaults to the module-level
        ``standard_queries`` list as it was when this function was defined.

    Returns
    -------
    polars.DataFrame
        The vertically stacked results of all queries, in query order, with a
        ``query_text`` column holding the originating query and five empty
        (null) label columns: ``mo_label``, ``kk_label``, ``kp_label``,
        ``dr_label`` and ``sr_label``.

    Raises
    ------
    ValueError
        If ``standard_queries`` is empty, since there is nothing to stack.
    """
    if len(standard_queries) == 0:
        raise ValueError("standard_queries must contain at least one query.")

    ## Run each query and tag its hits with the query text they belong to
    per_query_results = [
        ask_a_query(query_text).with_columns(query_text = pl.lit(query_text))
        for query_text in standard_queries
    ]

    ## Stack all frames in a single pass; concatenating inside the loop would
    ## re-copy the accumulated frame on every iteration.
    results = pl.concat(per_query_results)

    ## Add columns for our labels
    results = results.with_columns(
        mo_label = pl.lit(None),
        kk_label = pl.lit(None),
        kp_label = pl.lit(None),
        dr_label = pl.lit(None),
        sr_label = pl.lit(None)
    )

    return results

In [5]:

## Check if the file already exists.
if os.path.exists(SAVE_PATH):
    ## If it does, do nothing, so existing results are never overwritten.
    ## The message reports SAVE_PATH, the path that was actually checked.
    print("WARNING: The file", SAVE_PATH, "already exists. No results saved.")
else:
    ## If not, generate the query results and write them.
    results = generate_query_results_for_labeling()
    results.write_parquet(SAVE_PATH)