# Improving RAG by Averaging

#### Authors: Gilyoung Cheong, Qidu Fu, Junichi Koganemaru, Xinyuan Lai, Sixuan Lou, Dapeng Shang
 
This notebook was written as part of the capstone project for the [Erdős Institute Data Science Boot Camp](https://www.erdosinstitute.org/). The data used in this notebook is provided by Jason Morgan at AwareHQ. We use Gemma 2B-IT through the HuggingFace API, which we learned from [this article](https://huggingface.co/learn/cookbook/en/rag_with_hugging_face_gemma_mongodb) by Richmond Alake.

In this notebook, we implement two pipelines of [Retrieval-Augmented Generation (RAG)](https://aws.amazon.com/what-is/retrieval-augmented-generation/) using [SBERT](https://arxiv.org/abs/1908.10084), developed by Nils Reimers and Iryna Gurevych. The documentation for the SBERT API for Python is available at [this link](https://sbert.net/). We use SBERT to find comments relevant to a query about Walmart employees from the [Walmart Employees subreddit](https://www.reddit.com/r/WalmartEmployees/); we only use 10400 comments from previously saved data.

The pretrained SBERT model converts any sentence into a vector in $\mathbb{R}^{1024}$, and the relevance of two sentences to each other is measured simply by the cosine similarity of the corresponding vectors. That is, if $\boldsymbol{u}$ and $\boldsymbol{v}$ are the vectors, we measure

$$\frac{\langle \boldsymbol{u}, \boldsymbol{v} \rangle}{\|\boldsymbol{u}\|\|\boldsymbol{v}\|},$$

which can be intuitively thought of as $\cos(\theta)$, where $\theta$ is the angle between $\boldsymbol{u}$ and $\boldsymbol{v}$.
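
As a small illustration of the formula above, the following sketch computes the cosine similarity of toy 2-dimensional vectors with NumPy. The helper name `cosine` and the toy vectors are only for illustration and do not appear elsewhere in this notebook.

```python
import numpy as np

def cosine(u, v):
    """Cosine similarity <u, v> / (||u|| ||v||) for two 1-D vectors."""
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

# Toy 2-D vectors standing in for 1024-dimensional sentence embeddings.
toy_u = np.array([1.0, 0.0])
toy_v = np.array([1.0, 1.0])   # 45 degrees away from toy_u
toy_w = np.array([0.0, 2.0])   # 90 degrees away from toy_u

print(cosine(toy_u, toy_v))  # approximately 0.7071, i.e. cos(45 degrees)
print(cosine(toy_u, toy_w))  # 0.0
```

Note that the length of a vector does not matter: `toy_w` has norm 2, yet its similarity to `toy_u` depends only on the angle.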

## Benefits of SBERT vs BERT

SBERT (Sentence-BERT) is based on [BERT (Bidirectional Encoder Representations from Transformers)](https://arxiv.org/abs/1810.04805), developed by Google. From our own inspection, there are clear benefits of using SBERT over BERT for our purpose.

1. BERT is designed to generate vectors that correspond to individual words (or more precisely *subwords*) in a sentence, so each sentence is converted into not just a vector but a sequence of vectors. Hence, in order to examine the similarity of two sentences, we need to either pick one word or take the average of the vectors, neither of which yielded satisfying results.

2. Because BERT converts every subword into a vector, we need a lot more storage in order to fully use it. For our purpose of examining 10400 comments, storing the vectors required 11.8GB with BERT but only 91.6MB with SBERT.

3. For BERT, the query and the comments (i.e., information to answer the query) need to be processed together when we embed them as (sequences of) vectors. For SBERT, we can vectorize the comments first and then independently vectorize the query later.
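
To picture the first two points, here is a minimal sketch with random arrays standing in for real model outputs (an assumption: no BERT model is loaded, and the toy dimension 4 is far smaller than a real embedding size). A token-level encoder yields one row per subword, so the storage grows with sentence length, while averaging over the rows ("mean pooling") collapses each sentence to a single vector.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 4  # toy embedding size, chosen only for illustration

# Pretend token-level outputs: one row per subword, so the shapes differ.
sentence_a_tokens = rng.normal(size=(7, dim))    # a sentence with 7 subwords
sentence_b_tokens = rng.normal(size=(12, dim))   # a sentence with 12 subwords

# Mean pooling: average over the token axis to get one vector per sentence.
sentence_a_vec = sentence_a_tokens.mean(axis=0)
sentence_b_vec = sentence_b_tokens.mean(axis=0)

print(sentence_a_tokens.shape, sentence_b_tokens.shape)  # (7, 4) (12, 4)
print(sentence_a_vec.shape, sentence_b_vec.shape)        # (4,) (4,)
```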

## Naive RAG vs Not-so-naive RAG

* For us, the naive RAG means that we find the top 5 comments most relevant to the query and use them to generate a response to the query using an LLM.
* For the not-so-naive RAG, we let our LLM generate more queries similar to the original query and re-rank the comments by the average cosine similarity (i.e., averaging the cosine similarities of each comment to all of the queries). Then we use the top 5 comments to generate a response to the query using the LLM.

We shall see that both LLM-generated responses are quite satisfactory and difficult to evaluate. Hence, we shall only evaluate the retrieval process (i.e., the ranking of comments).

In [70]:
# Import necessary libraries (more to be added later)
from transformers import AutoTokenizer, AutoModelForCausalLM # HuggingFace API to use Gemma (LLM)
import torch
import random
from sentence_transformers import SentenceTransformer # Sentence BERT
import pyarrow.parquet as pq
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity

In [71]:
# Set a random seed
random_seed = 42
random.seed(random_seed)

## 1. Loading and cleaning the data

In [72]:
path = "data\\reddit.parquet"
df = pd.read_parquet(path, engine='fastparquet')

In [73]:
df.sample(5)

Unnamed: 0,aware_post_type,aware_created_ts,reddit_id,reddit_name,reddit_created_utc,reddit_author,reddit_text,reddit_permalink,reddit_title,reddit_url,reddit_subreddit,reddit_link_id,reddit_parent_id,reddit_submission
4242968,comment,2024-01-02T22:01:23,kg2p1im,t1_kg2p1im,1704251000.0,Hold-My-Shuriken,I’m going straight for the tires and a fucking...,/r/UPSers/comments/18x44dl/what_do_i_do/kg2p1im/,,,UPSers,t3_18x44dl,t1_kg2a8x8,18x44dl
959158,comment,2023-05-28T10:27:18,jly3wiu,t1_jly3wiu,1685284000.0,NeitherCapital1541,Dudes straight tiggered lmao,/r/McDonaldsEmployees/comments/13tph1m/fuck_cu...,,,McDonaldsEmployees,t3_13tph1m,t1_jlxx9wo,13tph1m
3441407,comment,2023-11-08T23:04:15,k8giunb,t1_k8giunb,1699503000.0,celestee3,I was so sad about the oat fudge bar 😭,/r/starbucks/comments/17r03vw/whats_one_menu_i...,,,starbucks,t3_17r03vw,t1_k8ft2il,17r03vw
2308810,comment,2023-08-31T09:27:18,jyimlwn,t1_jyimlwn,1693488000.0,x_scion_x,"I want to say yes, but I like having job secur...",/r/sysadmin/comments/165x63z/should_all_employ...,,,sysadmin,t3_165x63z,t3_165x63z,165x63z
4916766,comment,2024-02-02T15:10:06,komx5ra,t1_komx5ra,1706905000.0,mootmahsn,We use Epic Securechat where I work. Not attac...,/r/nursing/comments/1ah85y2/mds_policing_rns_c...,,,nursing,t3_1ah85y2,t1_komboo1,1ah85y2


In [74]:
df.columns

Index(['aware_post_type', 'aware_created_ts', 'reddit_id', 'reddit_name',
       'reddit_created_utc', 'reddit_author', 'reddit_text',
       'reddit_permalink', 'reddit_title', 'reddit_url', 'reddit_subreddit',
       'reddit_link_id', 'reddit_parent_id', 'reddit_submission'],
      dtype='object')

In [75]:
df = df[~(df['reddit_text'] == '')] # erasing empty reddit texts
df = df[~(df['reddit_text']=='[removed]')] # erasing removed reddit texts
df = df[~(df['reddit_text']=='[deleted]')] # erasing deleted reddit texts
df = df.sort_values(by='reddit_text') # sort them by reddit texts
df = df.reset_index().drop(columns='index') # resetting indices

In [76]:
df.sample(5)

Unnamed: 0,aware_post_type,aware_created_ts,reddit_id,reddit_name,reddit_created_utc,reddit_author,reddit_text,reddit_permalink,reddit_title,reddit_url,reddit_subreddit,reddit_link_id,reddit_parent_id,reddit_submission
1617747,comment,2023-12-08T12:19:15,kciusgq,t1_kciusgq,1702056000.0,cosmicquakingmess,I know what you mean. It makes me cringe so ha...,/r/nursing/comments/18dpc5m/new_workplace_is_f...,,,nursing,t3_18dpc5m,t1_kcifurt,18dpc5m
1645767,comment,2023-06-16T02:58:23,jobmw1q,t1_jobmw1q,1686899000.0,lifetakesguts,I love my job. I work in the ED and it gets ef...,/r/nursing/comments/14ak8ef/does_everyone_hate...,,,nursing,t3_14ak8ef,t3_14ak8ef,14ak8ef
1064356,comment,2023-08-02T03:35:40,jugd38b,t1_jugd38b,1690962000.0,Siran_Amaya,HAHAHA you think you can sue a private busines...,/r/walmart/comments/15fyzlp/attention_customer...,,,walmart,t3_15fyzlp,t1_jugbeoc,15fyzlp
1947405,submission,2023-04-06T23:23:42,12e8478,t3_12e8478,1680838000.0,cest_rien,I would really like to link my brokerage accou...,/r/fidelityinvestments/comments/12e8478/when_c...,When can we get cash manager for Fidelity Bloo...,https://www.reddit.com/r/fidelityinvestments/c...,fidelityinvestments,,,
3025094,comment,2023-11-11T23:57:55,k8w3p7p,t1_k8w3p7p,1699765000.0,D1g1talF00tpr1nt,Not worth it imo given we're on the brink of WW3,/r/cybersecurity/comments/17taggj/going_into_t...,,,cybersecurity,t3_17taggj,t3_17taggj,17taggj


We focus on the 'WalmartEmployees' subreddit for our initial implementation.

In [77]:
df_warmart = df[df['reddit_subreddit']=='WalmartEmployees']
df_warmart.info()

<class 'pandas.core.frame.DataFrame'>
Index: 10405 entries, 78 to 5449063
Data columns (total 14 columns):
 #   Column              Non-Null Count  Dtype  
---  ------              --------------  -----  
 0   aware_post_type     10405 non-null  object 
 1   aware_created_ts    10405 non-null  object 
 2   reddit_id           10405 non-null  object 
 3   reddit_name         10405 non-null  object 
 4   reddit_created_utc  10405 non-null  float64
 5   reddit_author       10405 non-null  object 
 6   reddit_text         10405 non-null  object 
 7   reddit_permalink    10405 non-null  object 
 8   reddit_title        1569 non-null   object 
 9   reddit_url          1569 non-null   object 
 10  reddit_subreddit    10405 non-null  object 
 11  reddit_link_id      8836 non-null   object 
 12  reddit_parent_id    8836 non-null   object 
 13  reddit_submission   8836 non-null   object 
dtypes: float64(1), object(13)
memory usage: 1.2+ MB


In [78]:
df_warmart = df_warmart.sort_values(by='reddit_text') # sort them by reddit texts
df_warmart = df_warmart.reset_index().drop(columns='index') # resetting indices

In [79]:
df_warmart[df_warmart["reddit_text"].apply(len) < 2] # reddit texts with length 1

Unnamed: 0,aware_post_type,aware_created_ts,reddit_id,reddit_name,reddit_created_utc,reddit_author,reddit_text,reddit_permalink,reddit_title,reddit_url,reddit_subreddit,reddit_link_id,reddit_parent_id,reddit_submission
179,comment,2023-04-14T19:54:30,jgaphpq,t1_jgaphpq,1681516000.0,throwawaywalmart117,5,/r/WalmartEmployees/comments/12mhu2m/do_you_ge...,,,WalmartEmployees,t3_12mhu2m,t3_12mhu2m,12mhu2m
180,comment,2023-04-15T13:53:28,jgdt7nh,t1_jgdt7nh,1681581000.0,TheMr91071,5,/r/WalmartEmployees/comments/12mhu2m/do_you_ge...,,,WalmartEmployees,t3_12mhu2m,t3_12mhu2m,12mhu2m
233,comment,2023-11-26T23:54:38,kaxrhv7,t1_kaxrhv7,1701061000.0,IronCityMMA,?,/r/WalmartEmployees/comments/184t7hl/never_doi...,,,WalmartEmployees,t3_184t7hl,t3_184t7hl,184t7hl
234,comment,2023-04-27T09:30:24,jhx0uuw,t1_jhx0uuw,1682602000.0,abbymarie67,?,/r/WalmartEmployees/comments/1303l7m/anyone_el...,,,WalmartEmployees,t3_1303l7m,t1_jhwd5zs,1303l7m
4920,comment,2023-02-26T00:16:18,ja1u21i,t1_ja1u21i,1677389000.0,Apprehensive-Ad-8858,K,/r/WalmartEmployees/comments/11c5fcb/tax_seaso...,,,WalmartEmployees,t3_11c5fcb,t3_11c5fcb,11c5fcb
4934,comment,2023-11-15T16:32:06,k9ewstd,t1_k9ewstd,1700084000.0,Accomplished-Ad-482,L,/r/WalmartEmployees/comments/17w1379/keep_your...,,,WalmartEmployees,t3_17w1379,t1_k9elybq,17w1379
10360,comment,2021-06-11T12:20:13,h1f2ir0,t1_h1f2ir0,1623428000.0,cosmic_bb_v,🏆,/r/WalmartEmployees/comments/njc5og/how_come_w...,,,WalmartEmployees,t3_njc5og,t1_h1653g9,njc5og
10361,comment,2022-01-29T20:23:02,husl5go,t1_husl5go,1643506000.0,naen77,🐀,/r/WalmartEmployees/comments/sfanf1/should_i_s...,,,WalmartEmployees,t3_sfanf1,t3_sfanf1,sfanf1
10363,comment,2023-07-03T19:25:41,jqkhcdq,t1_jqkhcdq,1688427000.0,Previous-Sun-4462,😀,/r/WalmartEmployees/comments/14pwho1/current_m...,,,WalmartEmployees,t3_14pwho1,t3_14pwho1,14pwho1
10364,comment,2022-12-20T00:46:20,j0xwudq,t1_j0xwudq,1671515000.0,PrettyGirlChaz431,😂,/r/WalmartEmployees/comments/zqdf32/many_open_...,,,WalmartEmployees,t3_zqdf32,t1_j0xvazj,zqdf32


In [80]:
df_warmart = df_warmart.drop(index = [10363, 10364, 10391, 10392, 10402]) # We drop some not-so-informative comments just so that we get a multiple of 200 comments

In [81]:
df_warmart = df_warmart.sort_values(by='reddit_text') # sort them by reddit texts
df_warmart = df_warmart.reset_index().drop(columns='index') # resetting indices
df_warmart

Unnamed: 0,aware_post_type,aware_created_ts,reddit_id,reddit_name,reddit_created_utc,reddit_author,reddit_text,reddit_permalink,reddit_title,reddit_url,reddit_subreddit,reddit_link_id,reddit_parent_id,reddit_submission
0,submission,2022-07-18T04:42:29,w1ucbr,t3_w1ucbr,1.658134e+09,Relative_Dream1659,\n\n\nI've recently began working at Walmart (...,/r/WalmartEmployees/comments/w1ucbr/unsure_of_...,unsure of what to do....,https://www.reddit.com/r/WalmartEmployees/comm...,WalmartEmployees,,,
1,comment,2023-08-31T02:55:29,jyhl1r5,t1_jyhl1r5,1.693465e+09,Broad_Payment1153,\n\n\nhttp://www.loansforfeds.com/je02131-sms-...,/r/WalmartEmployees/comments/165nas8/pto_and_p...,,,WalmartEmployees,t3_165nas8,t3_165nas8,165nas8
2,submission,2022-03-16T22:54:31,tg08ht,t3_tg08ht,1.647486e+09,Complete_Flow7713,\n\nMy bank account was compromised so I was t...,/r/WalmartEmployees/comments/tg08ht/direct_dep...,Direct deposit,https://www.reddit.com/r/WalmartEmployees/comm...,WalmartEmployees,,,
3,comment,2023-08-31T02:56:28,jyhl4jr,t1_jyhl4jr,1.693465e+09,Broad_Payment1153,\n\nhttp://www.loansforfeds.com/je02131-sms-1 ...,/r/WalmartEmployees/comments/165nowh/finally_l...,,,WalmartEmployees,t3_165nowh,t3_165nowh,165nowh
4,submission,2023-08-31T02:54:30,1663mai,t3_1663mai,1.693465e+09,Broad_Payment1153,\n\nhttp://www.loansforfeds.com/je02131-sms-1 ...,/r/WalmartEmployees/comments/1663mai/work_for_...,Work for Walmart,https://www.reddit.com/r/WalmartEmployees/comm...,WalmartEmployees,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10395,comment,2024-01-17T07:27:27,ki9qte2,t1_ki9qte2,1.705494e+09,AlpineLad1965,"🤦‍♂️ dang, I forgot to pay attention",/r/WalmartEmployees/comments/198p53b/why_do_yo...,,,WalmartEmployees,t3_198p53b,t1_ki9optn,198p53b
10396,comment,2023-09-04T08:05:49,jz31zx8,t1_jz31zx8,1.693829e+09,Electrical-Boss-3965,🤦🏿‍♂️,/r/WalmartEmployees/comments/169jr7g/i_survive...,,,WalmartEmployees,t3_169jr7g,t1_jz2zbzg,169jr7g
10397,comment,2022-03-28T09:29:55,i2fofcj,t1_i2fofcj,1.648474e+09,,🤭 it was worth a try,/r/WalmartEmployees/comments/tpjeuy/what_does_...,,,WalmartEmployees,t3_tpjeuy,t1_i2fmz55,tpjeuy
10398,comment,2023-11-26T22:14:52,kaxf9qp,t1_kaxf9qp,1.701055e+09,Sudden_Swim8998,🤷‍♂️ life happens. When i first started workin...,/r/WalmartEmployees/comments/184l6ww/will_i_be...,,,WalmartEmployees,t3_184l6ww,t1_kaxeeyp,184l6ww


## 2. Using SBERT to rank the reddit comments

### $\S2.1$ Partitioning the dataframe into 200-size batches

This partitioning lets us save the vectors in small files so that we do not run into memory problems. It matters more when we use BERT instead of SBERT, where the full set of vectors does not fit in local memory.

In [82]:
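# Input: a list (long enough that at least one cut is produced)
# Output: a list of consecutive parts of size 200; the final part holds the
# leftover items and is empty when 200 * round(len(items)/200) >= len(items)

def partition_200(items):
    L = []
    cuts = [200 * n for n in range(round(len(items)/200))]

    for c in cuts:
        L.append(items[c : c+200])
    L.append(items[c+200: len(items)])  # leftover items (possibly none)
    return L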

# We don't use the following, but we keep it for possible future use

def partition_d(list, d):
    L = [] # d is the size of each partition
    cuts = [d * n for n in range(round(len(list)/d))]

    for c in cuts:
        L.append(list[c : c+d])
    L.append(list[c+d: len(list)])
    return L

In [83]:
text = df_warmart['reddit_text'].to_list()
text_partition = partition_200(text)
len(text)

10400

In [84]:
lengths = list(map(len, text_partition))
print(sum(lengths)) # Checking if the partition was done right

10400
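
For comparison, here is a minimal sketch of a more general chunking helper. The name `chunk` is hypothetical and is not used elsewhere in this notebook. Unlike `partition_200`, it never produces an empty final part and also works for short lists, so the loading loop further below (which skips the last, empty part) would need adjusting if one swapped it in.

```python
def chunk(items, size):
    """Split a sequence into consecutive chunks of at most `size` items."""
    if size <= 0:
        raise ValueError("size must be positive")
    # Step through the start offsets 0, size, 2*size, ... and slice each chunk.
    return [items[start:start + size] for start in range(0, len(items), size)]

parts = chunk(list(range(10400)), 200)
print(len(parts), len(parts[-1]))              # 52 200
print(len(chunk(list(range(450)), 200)[-1]))   # 50
print(chunk([], 200))                          # []
```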


### $\S2.2$ Using SBERT to convert the Reddit comments into vectors

In [85]:
# https://huggingface.co/thenlper/gte-large
sentence_model = SentenceTransformer("thenlper/gte-large")

def get_sentence_embedding(text):
    if not text.strip(): # .strip() gets rid of new lines
        print("Attempted to get embedding for empty text.")
        return []

    embedding = sentence_model.encode(text)

    return embedding.tolist()

In [86]:
# Set a random seed for PyTorch (for GPU as well)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)

In [87]:
# The following saves the sentence vectors -- this does not need to be repeated once done

# with torch.no_grad():

#     for n in range(len(text_partition)):
#         outputs = []
#         for text in text_partition[n]:
#             outputs.append(get_sentence_embedding(text))
            
#         torch.save(outputs, f's_walmart_{200* n}_{200 * (n+1)-1}.pt')

# The final file, s_walmart_10400_10599.pt, is empty, so it can simply be deleted

In [88]:
text_partition[-1] # we got one more empty cell -- we shall just ignore this cell

[]

In [89]:
# Loading the sentence vectors in 200-batches

sentence_vector_partitions = []
for n in range(len(text_partition)-1):
    loaded_outputs = torch.load(f'data\\s_walmart_{200 * n}_{200 * (n+1)-1}.pt')
    sentence_vector_partitions.append(loaded_outputs)

In [90]:
print(len(sentence_vector_partitions)) # 52 * 200 = 10400
print(len(sentence_vector_partitions[-1])) # checking if the last cell is nonempty (of size 200)

52
200


In [91]:
sentence_vectors = []
for partition in sentence_vector_partitions:
    sentence_vectors += partition

print(len(sentence_vectors)) # sentence_vectors now have all 10400 vectors corresponding to the reddit comments

10400


The following functions measure the cosine similarity and the inner product of two given vectors. For SBERT, we observed that cosine similarity and inner product yield almost identical results, so we use cosine similarity for the rest of the notebook.

In [92]:
def cos_angle(v, w): # inputs v, w are vectors in 1-dimensional tensors (np.array(list))
    v = v.reshape(1,-1)
    w = w.reshape(1,-1)
    return cosine_similarity(v,w)

def inn(v, w):
    return v @ w
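
One possible reason the two measures agree so closely (an assumption on our part; we have not checked the norms of the stored embeddings here) is that the sentence vectors may all have nearly the same length. The sketch below, using random toy vectors, shows that once vectors are scaled to unit length the inner product coincides with the cosine similarity.

```python
import numpy as np

rng = np.random.default_rng(0)
toy_v = rng.normal(size=1024)
toy_w = rng.normal(size=1024)

cosine_vw = (toy_v @ toy_w) / (np.linalg.norm(toy_v) * np.linalg.norm(toy_w))

# After scaling each vector to unit length, the plain inner product
# is the same number as the cosine similarity of the originals.
v_unit = toy_v / np.linalg.norm(toy_v)
w_unit = toy_w / np.linalg.norm(toy_w)
inner_of_units = v_unit @ w_unit

print(np.isclose(cosine_vw, inner_of_units))  # True
```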

We will work with the following query: "How many PTOs does a regular employee have a year?"

In [93]:
query = "How many PTOs does a regular employee have a year?"
query_vector = get_sentence_embedding(query)
query_vector = np.array(query_vector)

### $\S2.1$ Ranking relevance of the comments to the query by cosine similarities

In [94]:
cos_angles = []
for i in range(len(sentence_vectors)):
    sentence_vector = np.array(sentence_vectors[i])
    cos_angles.append(cos_angle(query_vector, sentence_vector)[0][0])

cos_angles # We shall keep the order of indices to keep track of the vectors

[0.7531367665089905,
 0.7211716846120374,
 0.7198770863199504,
 0.7211716846120374,
 0.7211716846120374,
 0.765540738657215,
 0.773639857781653,
 0.7732377222531612,
 0.7373846945795265,
 0.744867122401287,
 0.7352134998227425,
 0.7655847651499748,
 0.7297354342282854,
 0.7701711327293881,
 0.7254842409253628,
 0.7900648120122646,
 0.7760302958381523,
 0.7442008203808539,
 0.7641805460486946,
 0.7405863364288567,
 0.7678929975497437,
 0.7151102575932938,
 0.7197853858370853,
 0.7852149102024455,
 0.7873830403495579,
 0.824054193615407,
 0.7249708433391602,
 0.8007364328828361,
 0.7290406551280623,
 0.7580510732776432,
 0.7498251116630137,
 0.7643833605835479,
 0.7417099942360259,
 0.7850464835408572,
 0.7673653915127828,
 0.7497045151674337,
 0.6540932479799919,
 0.7509070260994909,
 0.7509070260994909,
 0.7509070260994909,
 0.7509070260994909,
 0.7604918170103101,
 0.7696952156114458,
 0.7582802186785927,
 0.7612537570221423,
 0.7319905870608392,
 0.7556282864958137,
 0.76851621348474

We rank indices of the 10400 comments based on the cosine similarity:

In [95]:
cos_ranked_indices = np.array(cos_angles).argsort() 
cos_ranked_indices = cos_ranked_indices.tolist()
cos_ranked_indices.reverse() # This is so that we get descending order
cos_ranked_indices

[357,
 5751,
 4911,
 1685,
 2682,
 8993,
 5728,
 10163,
 2288,
 9757,
 5074,
 120,
 6232,
 77,
 3834,
 6742,
 6075,
 5859,
 3377,
 8189,
 9465,
 5114,
 6303,
 4039,
 7396,
 5781,
 8782,
 7367,
 425,
 4561,
 4510,
 8847,
 2841,
 6775,
 9278,
 4155,
 103,
 6180,
 7747,
 5980,
 8810,
 9247,
 9751,
 1915,
 1488,
 994,
 10084,
 2843,
 5175,
 8017,
 6305,
 7728,
 7889,
 8385,
 7603,
 1618,
 8145,
 5378,
 6358,
 2433,
 4079,
 5687,
 9466,
 5075,
 6371,
 5498,
 5431,
 9671,
 6261,
 5306,
 6363,
 2510,
 3997,
 2626,
 1406,
 2282,
 7964,
 9330,
 185,
 525,
 749,
 2702,
 7773,
 6306,
 8843,
 5803,
 2031,
 2681,
 4814,
 8178,
 8582,
 171,
 5693,
 9281,
 9943,
 6182,
 3043,
 2983,
 9376,
 7560,
 5486,
 1545,
 1169,
 552,
 6389,
 9559,
 735,
 8164,
 5354,
 9250,
 3679,
 5546,
 2891,
 6698,
 2883,
 10325,
 6321,
 566,
 4644,
 8171,
 6107,
 6262,
 8730,
 9311,
 4056,
 1631,
 556,
 6260,
 9555,
 4545,
 2442,
 911,
 8408,
 6556,
 4089,
 2666,
 2976,
 3775,
 4447,
 9361,
 9002,
 4065,
 4775,
 1269,
 4038
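
The scoring loop and the descending sort above can also be written in vectorized form. The following is a minimal sketch on toy data; the function name `rank_by_cosine` and the toy arrays are illustrative and not part of the notebook's pipeline.

```python
import numpy as np

def rank_by_cosine(query_vec, doc_matrix):
    """Return (indices sorted by descending cosine similarity, similarities)."""
    q = np.asarray(query_vec, dtype=float)
    D = np.asarray(doc_matrix, dtype=float)
    # One matrix-vector product gives all inner products at once.
    sims = (D @ q) / (np.linalg.norm(D, axis=1) * np.linalg.norm(q))
    order = np.argsort(-sims, kind="stable")  # negate for descending order
    return order, sims

toy_docs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]])
toy_query = np.array([1.0, 0.2])
order, sims = rank_by_cosine(toy_query, toy_docs)
print(order.tolist())  # [0, 2, 1, 3]
```

With the real data, `doc_matrix` would be `np.array(sentence_vectors)`, and `order[:5]` would give the top 5 indices.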

We check that the indices are ranked as intended:

In [96]:
for i in cos_ranked_indices:
    print(cos_angles[i])

0.8828317893774706
0.8819889507027516
0.8800527399006086
0.8785763826143662
0.8777285020681045
0.8763098560143696
0.8746375849750487
0.8737874171549231
0.8722765695960524
0.8669434036030312
0.8662125692195775
0.8647107238354834
0.8644504776074813
0.8625648399281878
0.8624278461252207
0.8623923749727399
0.8613801516498696
0.8580153549963417
0.8575977160816813
0.8572187000384022
0.8565265523017254
0.8564028293924406
0.8563546552779786
0.85633629256666
0.8560131125756054
0.8554775076352336
0.8552294124968952
0.8549767104781456
0.8538550547240407
0.8534944851616155
0.8532461015970272
0.852901705472458
0.8528979795501725
0.8526670125177468
0.8525098201317941
0.8524493910727307
0.8522997664196033
0.8519692823414928
0.8513179706324792
0.851084687681098
0.849893278677031
0.8490534409227255
0.849042474619869
0.8488116349030427
0.8483316313135805
0.847336236878401
0.8467510950122691
0.8467333312970675
0.8467265184631496
0.8463207356132372
0.8456241406426899
0.8444706016096898
0.8440467509822414


In [97]:
# Printing the comments ranked by cosine similarities

print("Query: ", query)

for n, i in enumerate(cos_ranked_indices):
    print(f"{n+1}:", df_warmart["reddit_text"].iloc[i])

Query:  How many PTOs does a regular employee have a year?
1: All associates earn PPTO, which is intended for emergencies (car won't start, you're sick, etc)

Full-time associates also earn PTO, which is intended for scheduled absences (doctor appointments, getting your drivers license renewed, vacations, etc). **Part-time associates do not earn PTO until they have been with the company for 3 years,** because part-time associates should be able to schedule their appointments and other errands on their days off.
2: Not sure about PTO, but the most ppto you can earn in a year is 48 hours, with the exception of a couple of states which are unlimited by state law.
3: Just went through onboarding (for Walmart distribution) and my chart clearly has 1 PPTO/30 hrs worked regardless of tenure. Regular pto is closer to how you described. YMMV
4: Full timers generally earn pto faster. But most part timers actually earn ppto faster. For most associates in most states, ppto is 1 hour for every 30 h

### $\S2.3$ Producing alternative queries

We now use an LLM to generate alternative queries that are similar to our original query: "How many PTOs does a regular employee have a year?" In this notebook we use Gemma 2B-IT through the HuggingFace API.

**Warning**. To reproduce the following, one may need a HuggingFace API key (which is free for the purpose of this notebook).

In [98]:
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
# CPU Enabled uncomment below
# model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")
# GPU Enabled use below
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", device_map="auto")

Loading checkpoint shards: 100%|██████████| 2/2 [00:12<00:00,  6.21s/it]


We write a function that produces a response from our LLM given a prompt:

In [99]:
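def llm(prompt):
    # Tokenize the prompt and move the tensors to the device that holds the model
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Generate up to 512 new tokens; the decoded text echoes the prompt and special tokens
    response = model.generate(**inputs, max_new_tokens=512)
    return tokenizer.decode(response[0])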

In [100]:
prompt = "Generate 10 similar queries to the following: \"How many PTOs does a regular employee have a year?\" \n\n"
llm_response = llm(prompt)
print(llm_response)

<bos>Generate 10 similar queries to the following: "How many PTOs does a regular employee have a year?" 

1. How many sick days does a regular employee have a year?
2. How many vacation days does a regular employee have a year?
3. How many personal days does a regular employee have a year?
4. How many sick leave days does a regular employee have a year?
5. How many vacation leave days does a regular employee have a year?
6. How many personal leave days does a regular employee have a year?
7. How many sick days does a part-time employee have a year?
8. How many sick days does a full-time employee have a year?
9. How many sick days does a part-time employee have a year?
10. How many sick days does a full-time employee have a year?<eos>
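
The decoded output above repeats the prompt and contains the special tokens `<bos>` and `<eos>`. Here is a minimal sketch of a post-processing helper that keeps only the generated text. The name `extract_completion` is hypothetical, and the sample string is a shortened copy of the output above. It assumes the decoded text reproduces the prompt character for character, which is why the prefix is only removed when it actually matches.

```python
def extract_completion(decoded, prompt, special_tokens=("<bos>", "<eos>")):
    """Return only the newly generated text from a decoded model output."""
    text = decoded
    for token in special_tokens:
        text = text.replace(token, "")
    # Drop the echoed prompt, but only if the output really starts with it.
    if text.startswith(prompt):
        text = text[len(prompt):]
    return text.strip()

sample_prompt = (
    'Generate 10 similar queries to the following: '
    '"How many PTOs does a regular employee have a year?" \n\n'
)
sample_decoded = (
    "<bos>" + sample_prompt
    + "1. How many sick days does a regular employee have a year?\n"
    + "2. How many vacation days does a regular employee have a year?<eos>"
)
print(extract_completion(sample_decoded, sample_prompt))
```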


In [101]:
stopping_places = []

for i in range(10):
    num = str(i+1)
    stopping_places.append(llm_response.find(num + "."))

stopping_places

[106, 165, 228, 291, 356, 425, 494, 555, 616, 677]

We store the alternative queries as follows:

In [102]:
stopping_places.append(len(llm_response))
alt_queries = []

for i in range(len(stopping_places)-1):
    start = stopping_places[i]
    end = stopping_places[i+1]
    alt_queries.append(llm_response[start:end])

alt_queries

['1. How many sick days does a regular employee have a year?\n',
 '2. How many vacation days does a regular employee have a year?\n',
 '3. How many personal days does a regular employee have a year?\n',
 '4. How many sick leave days does a regular employee have a year?\n',
 '5. How many vacation leave days does a regular employee have a year?\n',
 '6. How many personal leave days does a regular employee have a year?\n',
 '7. How many sick days does a part-time employee have a year?\n',
 '8. How many sick days does a full-time employee have a year?\n',
 '9. How many sick days does a part-time employee have a year?\n',
 '10. How many sick days does a full-time employee have a year?<eos>']
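
The stored strings above still carry the list numbering, trailing newlines, and the `<eos>` token, and two of the queries appear twice. As an alternative sketch (not what this notebook runs, so the similarity values below would differ slightly), a regular expression can extract clean, de-duplicated queries. The function name `parse_numbered_queries` is hypothetical.

```python
import re

def parse_numbered_queries(text):
    """Extract the items of a numbered list, without numbering or special tokens."""
    text = text.replace("<bos>", "").replace("<eos>", "")
    # Match lines such as "3. Some question?" and capture the part after "3. ".
    items = re.findall(r"^\s*\d+\.\s+(.+?)\s*$", text, flags=re.MULTILINE)
    # Remove exact duplicates while keeping the first occurrence of each item.
    return list(dict.fromkeys(items))

sample_response = (
    "1. How many sick days does a regular employee have a year?\n"
    "2. How many vacation days does a regular employee have a year?\n"
    "7. How many sick days does a part-time employee have a year?\n"
    "9. How many sick days does a part-time employee have a year?<eos>"
)
for q in parse_numbered_queries(sample_response):
    print(q)
```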

In [103]:
# We vectorize the alternative queries

alt_query_vectors = []
for q in alt_queries:
    alt_query_vectors.append(get_sentence_embedding(q))

In [104]:
# We now compute the cosine similarity between the vector of each alternative query and our original query

alt_query_cos_angles = []
for v in alt_query_vectors: 
    alt_query_cos_angles.append(cos_angle(query_vector, np.array(v)))

alt_query_cos_angles

[array([[0.87719671]]),
 array([[0.88927791]]),
 array([[0.88713579]]),
 array([[0.86620137]]),
 array([[0.8783751]]),
 array([[0.87819074]]),
 array([[0.8630544]]),
 array([[0.84974469]]),
 array([[0.86013588]]),
 array([[0.85545097]])]

In [105]:
# cleaning the format

for i in range(len(alt_query_cos_angles)):
    alt_query_cos_angles[i] = alt_query_cos_angles[i][0][0]

alt_query_cos_angles 

[0.8771967077126346,
 0.8892779083672098,
 0.887135786329719,
 0.8662013683356176,
 0.8783751047809851,
 0.8781907437340979,
 0.8630544036724722,
 0.8497446943358696,
 0.8601358811363096,
 0.8554509723857592]

We only keep the alternative queries whose cosine similarity to the original query is greater than 0.87:

In [106]:
indices_for_queries = [i for i in range(len(alt_query_cos_angles)) if alt_query_cos_angles[i] > 0.87]
indices_for_queries

[0, 1, 2, 4, 5]

In [107]:
chosen_queries = []
for i in indices_for_queries:
    chosen_queries.append(alt_queries[i])
chosen_queries

['1. How many sick days does a regular employee have a year?\n',
 '2. How many vacation days does a regular employee have a year?\n',
 '3. How many personal days does a regular employee have a year?\n',
 '5. How many vacation leave days does a regular employee have a year?\n',
 '6. How many personal leave days does a regular employee have a year?\n']

In [108]:
chosen_query_vectors = []
for i in indices_for_queries:
    chosen_query_vectors.append(alt_query_vectors[i])

### $\S2.4$ Re-ranking relevance of the comments by the average cosine similarities

For each comment, we take the average of the cosine similarities between the comment and all the queries (the chosen alternative queries together with the original query).

In [109]:
chosen_query_vectors.append(query_vector)

avg_cos_angles = []
for i in range(len(sentence_vectors)):
    sentence_vector = np.array(sentence_vectors[i])
    avg_cos_angle = 0
    for v in chosen_query_vectors:
        v = np.array(v)
        avg_cos_angle += cos_angle(v, sentence_vector)[0][0]
    avg_cos_angles.append(avg_cos_angle/len(chosen_query_vectors))
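
A useful way to read this average: since $\cos(\boldsymbol{q}_i, \boldsymbol{d}) = \langle \boldsymbol{q}_i/\|\boldsymbol{q}_i\|, \boldsymbol{d}/\|\boldsymbol{d}\| \rangle$, averaging over the queries is the same as taking one inner product between the unit comment vector and the centroid of the unit query vectors. The sketch below checks this on random toy data; the helper `unit_rows` and the toy arrays are illustrative only.

```python
import numpy as np

def unit_rows(M):
    """Scale every row of a 2-D array to unit Euclidean length."""
    M = np.asarray(M, dtype=float)
    return M / np.linalg.norm(M, axis=1, keepdims=True)

rng = np.random.default_rng(0)
toy_docs = rng.normal(size=(50, 8))     # stand-in for the comment vectors
toy_queries = rng.normal(size=(6, 8))   # original query plus 5 alternatives

# Direct definition: average, over the queries, of cos(query, doc).
all_cos = unit_rows(toy_docs) @ unit_rows(toy_queries).T   # shape (50, 6)
avg_cos = all_cos.mean(axis=1)

# Equivalent form: one inner product with the centroid of the unit queries.
centroid = unit_rows(toy_queries).mean(axis=0)
avg_cos_via_centroid = unit_rows(toy_docs) @ centroid

print(np.allclose(avg_cos, avg_cos_via_centroid))  # True
```

In this form the re-ranking needs only a single pass over the comment vectors, regardless of how many alternative queries are used.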

In [110]:
avg_cos_ranked_indices = np.array(avg_cos_angles).argsort() 
avg_cos_ranked_indices = avg_cos_ranked_indices.tolist()
avg_cos_ranked_indices.reverse() # This is so that we get descending order
avg_cos_ranked_indices

[9465,
 4545,
 823,
 120,
 7396,
 305,
 8847,
 9278,
 10241,
 4622,
 2288,
 5394,
 7728,
 6318,
 3971,
 3496,
 2682,
 6232,
 1618,
 5728,
 1488,
 171,
 2031,
 4379,
 77,
 192,
 9454,
 1621,
 3097,
 8681,
 1542,
 8993,
 8674,
 1348,
 6508,
 5751,
 2690,
 9396,
 115,
 357,
 9309,
 8932,
 439,
 2563,
 8814,
 2723,
 1343,
 3225,
 483,
 5959,
 167,
 1278,
 4155,
 2174,
 8843,
 2442,
 10163,
 6358,
 5693,
 1631,
 9379,
 6233,
 1516,
 4079,
 1915,
 7796,
 1406,
 7931,
 7603,
 7883,
 7948,
 1722,
 8730,
 4783,
 3377,
 4565,
 1214,
 1152,
 8582,
 4911,
 7548,
 1686,
 4454,
 9466,
 1683,
 5741,
 749,
 552,
 244,
 436,
 9004,
 173,
 5743,
 4944,
 4071,
 3268,
 5378,
 9758,
 560,
 6316,
 4390,
 163,
 2841,
 6107,
 1685,
 7747,
 486,
 3274,
 2510,
 7487,
 3921,
 667,
 6471,
 6389,
 7964,
 1620,
 3055,
 9846,
 3679,
 5390,
 8810,
 4815,
 360,
 7242,
 9991,
 1500,
 9376,
 5980,
 5585,
 8385,
 10015,
 7701,
 1011,
 1224,
 2681,
 76,
 9789,
 10320,
 4137,
 2976,
 8093,
 395,
 8189,
 1614,
 6065,
 9757,

In [111]:
# Printing the comments ranked by average cosine similarities

print("Query: ", query)

for n, i in enumerate(avg_cos_ranked_indices):
    print(f"{n+1}:", df_warmart["reddit_text"].iloc[i])

Query:  How many PTOs does a regular employee have a year?
1: You get protected pto(like sick time) from day 1 but you can't use it until day 90. Regular pto(like vacation time) you start getting at day 90.
2: It’s a business and it’s the holidays. They need all the help they can get. Since you were hired a month ago it’s very unlikely they will give you those days off. X eve and New Year’s Day at my store are two point days. So 4 total
3: Bullshit we have people at the store I work at take 2 to 3 weeks off all the time
4: 1.00 = one hour.

So you have almost an hour of ppto.


You have 68 hours of PTO.
 Dividend 68 by 8 (assuming youre covering an 8 hour work day)

You have 8 days you can take off fully, with some change.


So in total(assuming your PTO gets approved), you can take 9 days off in a row(as you accrued ppto when using PTO if I remember correctly).
5: That’s assuming they’re full time, or at least in their 3rd year. Otherwise it’s 1 for every 43.33 hours worked
6: Actuall

## 3. Feeding the ranked and re-ranked sentences to LLM

We introduce a function that generates a prompt for our LLM from the top $k$ ranked (or re-ranked) comments:

In [112]:
def rag_prompt(indices, k):
    information_to_feed = ""
    for n, i in zip(range(k), indices):
        information_to_feed += f"{n+1}: " + df_warmart["reddit_text"].iloc[i] + "\n"
    # concatenate the first top k comments
    combined_information = f"\nQuery: {query}\n\nAnswer the above query by only using the following:\n\n{information_to_feed}\n\nLLM Response:"
    
    return combined_information
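
The function above reads `query` and `df_warmart` from the notebook's global state. As a sketch of an alternative design, the same prompt layout can be built from explicit arguments, which makes it easy to try on toy inputs. The name `build_rag_prompt` and the placeholder comments are hypothetical.

```python
def build_rag_prompt(query, comments, k):
    """Build a prompt from a query and the first k comments of a ranked list."""
    # Number the comments 1..k, one per line, as in the prompt layout above.
    numbered = "".join(f"{n + 1}: {text}\n" for n, text in enumerate(comments[:k]))
    return (
        f"\nQuery: {query}\n\n"
        f"Answer the above query by only using the following:\n\n"
        f"{numbered}\n\nLLM Response:"
    )

toy_comments = ["comment A", "comment B", "comment C"]  # placeholder texts
toy_prompt = build_rag_prompt("example query?", toy_comments, 2)
print(toy_prompt)
```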

In [113]:
prompt_ranked = rag_prompt(cos_ranked_indices, 5) # prompt generated from top 5 comments ranked by cos sim to the original query
prompt_reranked = rag_prompt(avg_cos_ranked_indices, 5) # prompt generated from top 5 comments reranked by avg cos sim to the original query + alternative queries

In [114]:
llm_answer1 = llm(prompt_ranked)
llm_answer2 = llm(prompt_reranked)

In [115]:
print("Using top 5 comments ranked by cos sim to the original query \n\n", llm_answer1)

Using top 5 comments ranked by cos sim to the original query 

 <bos>
Query: How many PTOs does a regular employee have a year?

Answer the above query by only using the following:

1: All associates earn PPTO, which is intended for emergencies (car won't start, you're sick, etc)

Full-time associates also earn PTO, which is intended for scheduled absences (doctor appointments, getting your drivers license renewed, vacations, etc). **Part-time associates do not earn PTO until they have been with the company for 3 years,** because part-time associates should be able to schedule their appointments and other errands on their days off.
2: Not sure about PTO, but the most ppto you can earn in a year is 48 hours, with the exception of a couple of states which are unlimited by state law.
3: Just went through onboarding (for Walmart distribution) and my chart clearly has 1 PPTO/30 hrs worked regardless of tenure. Regular pto is closer to how you described. YMMV
4: Full timers generally earn p

In [116]:
print("Using top 5 comments reranked by cos sim to the original query + alternative queries \n\n", llm_answer2)

Using top 5 comments reranked by cos sim to the original query + alternative queries 

 <bos>
Query: How many PTOs does a regular employee have a year?

Answer the above query by only using the following:

1: You get protected pto(like sick time) from day 1 but you can't use it until day 90. Regular pto(like vacation time) you start getting at day 90.
2: It’s a business and it’s the holidays. They need all the help they can get. Since you were hired a month ago it’s very unlikely they will give you those days off. X eve and New Year’s Day at my store are two point days. So 4 total
3: Bullshit we have people at the store I work at take 2 to 3 weeks off all the time
4: 1.00 = one hour.

So you have almost an hour of ppto.


You have 68 hours of PTO.
 Dividend 68 by 8 (assuming youre covering an 8 hour work day)

You have 8 days you can take off fully, with some change.


So in total(assuming your PTO gets approved), you can take 9 days off in a row(as you accrued ppto when using PTO if 

## 4. Evaluation of retrieval

It is rather difficult to say which of the two LLM responses is better. Moreover, our goal is NOT to get an answer that is absolutely correct, but one that is relevant among the Reddit comments that we put in. For example, the correct answer may change over time, so the retrieved answer can become outdated unless we update the input comments.

Hence, we use both of the LLM responses as ground truths and compare the top 50 retrievals from the two methods:
* Method 1: Naive RAG using cosine similarities against the original query
* Method 2: Not-so-naive RAG using average cosine similarities against multiple similar queries, including the original one
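
To make the difference between the two methods concrete, here is a minimal sketch in NumPy on toy vectors. The function names `rank_by_query` and `rank_by_avg_query` and the toy data are assumptions for illustration only. The notebook's own ranking code, which produces `cos_ranked_indices` and `avg_cos_ranked_indices`, appears in the earlier sections and may differ in detail.

```python
import numpy as np


def cosine_matrix(docs, queries):
    """Cosine similarity between every row of `docs` and every row of `queries`."""
    docs = docs / np.linalg.norm(docs, axis=1, keepdims=True)
    queries = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    return docs @ queries.T  # shape: (num_docs, num_queries)


def rank_by_query(docs, query):
    """Method 1: rank documents by cosine similarity to a single query."""
    scores = cosine_matrix(docs, query[None, :])[:, 0]
    return np.argsort(-scores)  # indices sorted from most to least similar


def rank_by_avg_query(docs, queries):
    """Method 2: rank documents by cosine similarity averaged over several queries."""
    scores = cosine_matrix(docs, queries).mean(axis=1)
    return np.argsort(-scores)


# Toy data: 3 documents in R^2, the original query, and one alternative query.
docs = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
original = np.array([1.0, 0.2])
alternative = np.array([0.5, 1.0])

ranked = rank_by_query(docs, original)
reranked = rank_by_avg_query(docs, np.stack([original, alternative]))
print(ranked, reranked)
```

In this toy example the document that is moderately close to both queries overtakes the one that is close to the original query only, which is the effect that averaging is meant to produce.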

In [117]:
truth_1 = "Regular employees are entitled to 1 hour of paid time off per 30 hours worked, with a maximum of 48 hours per year."
truth_2 = "An employee is entitled to 68 hours of paid time off per year."

indices_1 = cos_ranked_indices[:50]
indices_2 = avg_cos_ranked_indices[:50]

# Embeddings of the top 50 comments retrieved by each method
vectors_1 = [sentence_vectors[i] for i in indices_1]
vectors_2 = [sentence_vectors[i] for i in indices_2]

### Evaluation metric 1: cosine precision
We now describe the function with which we evaluate the retrieval from each method. Let $\boldsymbol{t}_1$ and $\boldsymbol{t}_2$ be the truth vectors, that is, the SBERT embeddings of the two ground truths. For each vector $\boldsymbol{v}$ from a batch, the cosine similarities $\cos(\boldsymbol{t}_1, \boldsymbol{v})$ and $\cos(\boldsymbol{t}_2, \boldsymbol{v})$ lie in the interval $[-1, 1]$, and in all of our examples they lie in $[0, 1]$. We take the average of the two as a measure of how truthful $\boldsymbol{v}$ is: the closer the average is to $1$, the more truthful $\boldsymbol{v}$ is.

Recall the definition of **precision**:
$$\mathrm{Precision} := \frac{\mathrm{Relevant \ retrieved \ instances}}{\mathrm{All \ retrieved \ instances}}.$$

Given a batch $B$, we define the **cosine precision** as follows:

$$\mathrm{Cosine \ Precision \ of } \ B := \frac{1}{2|B|}\sum_{\boldsymbol{v} \in B}  (\cos(\boldsymbol{t}_1, \boldsymbol{v}) + \cos(\boldsymbol{t}_2, \boldsymbol{v}))$$

In [118]:
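def cos_precision(batch, t_1, t_2):
    t_1 = np.array(t_1)
    t_2 = np.array(t_2)

    # Use `total` rather than `sum` to avoid shadowing the built-in function.
    total = 0

    for v in batch:
        v = np.array(v)
        # Add the cosine similarities of v to both truth vectors.
        total += (cos_angle(t_1, v) + cos_angle(t_2, v))
    return total / (2*len(batch))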

In [119]:
t_1 = get_sentence_embedding(truth_1)
t_2 = get_sentence_embedding(truth_2)

In [120]:
cos_precision(vectors_1, t_1, t_2)

array([[0.83449691]])

In [121]:
cos_precision(vectors_2, t_1, t_2)

array([[0.85419629]])

Indeed, the averaging method improves on the naive RAG: the cosine precision increases from 0.83449691 to 0.85419629.
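
As a cross-check of the formula, the cosine precision can also be computed without an explicit loop. The following is a sketch under the assumption that each embedding can be flattened into a 1-D NumPy array. The name `cos_precision_vectorized` is hypothetical, and the toy vectors stand in for the SBERT embeddings used in the notebook.

```python
import numpy as np


def cos_precision_vectorized(batch, t_1, t_2):
    """Mean over the batch of the average cosine similarity to t_1 and t_2."""
    # Flatten each embedding so that inputs of shape (d,) or (1, d) both work.
    batch = np.asarray(batch, dtype=float).reshape(len(batch), -1)   # (|B|, d)
    truths = np.stack([np.ravel(t_1), np.ravel(t_2)]).astype(float)  # (2, d)
    batch = batch / np.linalg.norm(batch, axis=1, keepdims=True)
    truths = truths / np.linalg.norm(truths, axis=1, keepdims=True)
    # Entry (i, j) is cos(t_j, v_i); the mean over all entries equals
    # (1 / (2|B|)) * sum over v of (cos(t_1, v) + cos(t_2, v)).
    return float((batch @ truths.T).mean())


# Toy example: each batch vector matches exactly one of the two truths.
toy_batch = [[1.0, 0.0], [0.0, 1.0]]
toy_t_1 = [1.0, 0.0]
toy_t_2 = [0.0, 1.0]
toy_score = cos_precision_vectorized(toy_batch, toy_t_1, toy_t_2)
print(toy_score)
```

Unlike the notebook's `cos_precision`, whose output above is a nested array, this sketch returns a plain Python float.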

### Evaluation metric 2: ranked cosine precision

The following metric evaluates not only the retrieval but also the ranking of the retrieved contexts.

Assume we retrieved $K$ comments in the context, ranked as $B = (x_1, \ldots, x_K)$.

We define the **precision at rank $m$** to be the cosine precision of the truncated context $B_m := (x_1, \ldots, x_m)$. The **ranked cosine precision** is the average of these precisions:

$$
\text{Ranked Cosine Precision of } B := \frac{1}{K} \sum_{m = 1}^{K} \text{Cosine Precision of } B_m.
$$

Under this measurement, comments ranked higher in the retrieved context have a higher impact on the precision.
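
To see why, write $s(x_j)$ for the average of the two cosine similarities between the truth vectors and the embedding of $x_j$, so that the cosine precision of $B_m$ is $\frac{1}{m}\sum_{j=1}^{m} s(x_j)$. Exchanging the order of summation gives

$$
\text{Ranked Cosine Precision of } B = \frac{1}{K} \sum_{m = 1}^{K} \frac{1}{m} \sum_{j = 1}^{m} s(x_j) = \sum_{j = 1}^{K} w_j \, s(x_j), \quad \text{where } w_j := \frac{1}{K} \sum_{m = j}^{K} \frac{1}{m}.
$$

The weights $w_j$ are positive, decrease in $j$, and sum to $1$, so the ranked cosine precision is a weighted average that favors the top of the ranking. For instance, with $K = 50$ the first comment has weight $w_1 \approx 0.09$, while the last has weight $w_{50} = 1/2500 = 0.0004$. By comparison, the (unranked) cosine precision gives every comment the same weight $1/K$.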

In [None]:
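def cos_rank_precision(batch, t_1, t_2):
    # Use `total` rather than `sum` to avoid shadowing the built-in function.
    total = 0

    # Accumulate the cosine precision of each truncated context B_m = batch[:m].
    for m in range(1, len(batch)+1):
        total += cos_precision(batch[:m], t_1, t_2)

    return total / len(batch)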

In [1]:
cos_rank_precision(vectors_1, t_1, t_2)

array([[0.84330116]])


In [2]:
cos_rank_precision(vectors_2, t_1, t_2)

array([[0.85745651]])
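
The function `cos_rank_precision` recomputes the cosine precision from scratch for every truncated context. If the per-comment scores (the average of the two cosine similarities to the truth vectors) are computed once, the same quantity is the average of their running means. The sketch below illustrates this on toy scores. The name `ranked_precision_from_scores` and the numbers are assumptions for illustration and are not taken from the notebook's data.

```python
import numpy as np


def ranked_precision_from_scores(scores):
    """Average of the running means of `scores`, taken in rank order."""
    scores = np.asarray(scores, dtype=float)
    ranks = np.arange(1, len(scores) + 1)
    # running_means[m - 1] is the cosine precision of the truncated context B_m.
    running_means = np.cumsum(scores) / ranks
    return float(running_means.mean())


# Toy per-comment scores, best-ranked first.
toy_scores = [0.9, 0.7, 0.5]
toy_ranked = ranked_precision_from_scores(toy_scores)
toy_reversed = ranked_precision_from_scores(toy_scores[::-1])
print(toy_ranked, toy_reversed)
```

The plain mean of the toy scores is the same for both orderings, whereas the ranked value is larger when the highest-scoring comments come first. This is the sense in which the metric rewards a good ranking and not only a good retrieved set.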


## 5. Conclusion and future directions

As we have seen in the example above, our averaging method improves the overall retrieval: comparing each comment against multiple queries similar to the original one helps to get rid of possibly unrelated retrieved data. The LLM API we are using took a few minutes to generate 10 similar queries, and we could only use half of them to assure the quality of our result. We expect that a stronger LLM would not only make the process faster, but would also generate more usable similar queries, which may result in an even better retrieval outcome.