<a href="https://colab.research.google.com/github/dpshang/RAG/blob/main/RAG-Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Improving RAG by Averaging

#### Authors: Gilyoung Cheong, Qidu Fu, Junichi Koganemaru, Xinyuan Lai, Sixuan Lou, Dapeng Shang

This notebook is written as a part of the capstone project for the [Erdős Institute Data Science Boot Camp](https://www.erdosinstitute.org/). The data used in this notebook is provided by Jason Morgan at AwareHQ. We use Gemma 2B-IT using HuggingFace API, which we learned from [this article](https://huggingface.co/learn/cookbook/en/rag_with_hugging_face_gemma_mongodb) by Richmond Alake.

In this notebook, we implement some pipelines of [Retrieval-Augmented Generation (RAG)](https://aws.amazon.com/what-is/retrieval-augmented-generation/) using [SBERT](https://arxiv.org/abs/1908.10084) developed by Nils Reimers and Iryna Gurevych. The documentation for the SBERT API for Python is available in [this link](https://sbert.net/). We use SBERT to find relevant comments to a query from various reddit comments from previously saved data.

The pretrained SBERT converts any sentence into a vector in $\mathbb{R}^{1024}$, and the relevance of the two sentences is simply measured by the cosine similarity of the corresponding vectors. That is, if $\boldsymbol{u}$ and $\boldsymbol{v}$ are the vectors, we measure

$$\frac{\langle \boldsymbol{u}, \boldsymbol{v} \rangle}{\|\boldsymbol{u}\|\|\boldsymbol{v}\|},$$

which can be intuitively thought as $\cos(\theta)$, where $\theta$ is the angle between $\boldsymbol{u}$ and $\boldsymbol{v}$.

## Benefits of SBERT vs BERT

SBERT (Sentence Bert) is based on [BERT (Bidirectional Encoder Representations from Transformer)](https://arxiv.org/abs/1810.04805) developed by Google. From inspection, there are clear benefits of using SBERT over BERT for our purpose.

1. BERT is designed to generate vectors that correspond to individual words (or more precisely, *subwords*) to a sentence, so each sentence is converted into not just a vector but a sequence of vectors. Hence, in order to examine the similaritiy of two sentences, we need to either pick one word or take the average of the vectors, which did not yield satisfying results.

2. Because BERT converts every subword as a vector, in order to fully use it, we need to use a lot more storage. In an experiement, examining 10400 comments required 11.8GB with BERT while it only required 91.6MB with SBERT.

3. For BERT, the query and the comments (i.e., information to answer the query) need to be proceeded together when we embedd them as (sequences of) vectors. For SBERT, we can vectorize the comments first and then indepedently vectorize the query later.

In [7]:
# ! pip install sentence_transformers
# ! pip install accelerate

In [8]:
# Importing necessary libraries
import numpy as np
import pandas as pd
import torch
import re
import nltk
import os.path

from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
from tqdm import tqdm
from sentence_transformers import SentenceTransformer # SBERT
from sklearn.metrics.pairwise import cosine_similarity
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AutoTokenizer, AutoModelForCausalLM # HuggingFace API to use Gemma (LLM)
from sklearn.cluster import KMeans
from sklearn import preprocessing

from warnings import filterwarnings
filterwarnings('ignore')

In [9]:
from google.colab import drive
drive.mount('/drive', force_remount=True)

Mounted at /drive


In [10]:
# Defining model for synthetic query generation
tokenizer = T5Tokenizer.from_pretrained('BeIR/query-gen-msmarco-t5-large-v1')
model = T5ForConditionalGeneration.from_pretrained('BeIR/query-gen-msmarco-t5-large-v1')

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


## 1. Loading and cleaning the data

In [11]:
# Reading in reddit data and specifying columns to use
columns = ['reddit_text', 'reddit_subreddit']
df = pd.read_parquet("/drive/My Drive/Colab Notebooks/RAG/reddit.parquet", columns = columns)

In [12]:
# Identifying candidate subreddits to analyze
value_count = df['reddit_subreddit'].value_counts()
value_count

reddit_subreddit
nursing                789499
walmart                630962
sysadmin               557558
starbucks              393597
WaltDisneyWorld        373138
Target                 340401
UPSers                 262483
Disneyland             231981
Lowes                  198805
CVS                    179598
McDonaldsEmployees     174679
cybersecurity          161868
Fedexers               154572
GameStop               137071
starbucksbaristas      132019
fidelityinvestments    129423
Bestbuy                121077
wholefoods              82052
Panera                  79436
DisneyWorld             65549
DollarTree              59745
TjMaxx                  46286
disney                  43954
McLounge                38627
GeneralMotors           37277
TalesFromYourBank       28444
cabincrewcareers        23408
Chase                   16931
KrakenSupport           14533
WalmartEmployees        10752
BestBuyWorkers           5629
RiteAid                  3970
PaneraEmployees        

In [13]:
# For proof of concept, we decided to work with medium sized subreddits first
value_count[ value_count < 100000]

reddit_subreddit
wholefoods           82052
Panera               79436
DisneyWorld          65549
DollarTree           59745
TjMaxx               46286
disney               43954
McLounge             38627
GeneralMotors        37277
TalesFromYourBank    28444
cabincrewcareers     23408
Chase                16931
KrakenSupport        14533
WalmartEmployees     10752
BestBuyWorkers        5629
RiteAid               3970
PaneraEmployees       2694
FedEmployees           280
Name: count, dtype: int64

In [14]:
selected_subreddts = ['TalesFromYourBank', 'cabincrewcareers', 'Chase', 'KrakenSupport', 'WalmartEmployees', 'BestBuyWorkers', 'RiteAid', 'PaneraEmployees', 'FedEmployees']

In [15]:
df = df[df['reddit_subreddit'].isin(selected_subreddts)]
len(df)

106641

In [16]:
# Basic data cleaning by removing rows with empty/removed/deleted texts
def remove_empty(df):
    df = df[~(df['reddit_text'] == '')] # erasing empty reddit texts
    df = df[~(df['reddit_text']=='[removed]')] # erasing removed reddit texts
    df = df[~(df['reddit_text']=='[deleted]')] # erasing deleted reddit texts
    df = df[df['reddit_text'].str.len() > 5] #Only kee ping responses longer than 5 characters
    df = df.sort_values(by='reddit_text') # sort them by reddit texts
    df = df.reset_index().drop(columns='index') # resetting indices
    return df

In [17]:
def remove_url(text):
    pattern = re.compile(r'https?://\S+|www.\.\S+')
    return pattern.sub(r'', text)

In [18]:
# Download necessary NLTK data
nltk.download('wordnet')
nltk.download('stopwords')

# Initialize WordNetLemmatizer
LEMMATIZER = WordNetLemmatizer()

# Load stopwords
STOP_WORDS = set(stopwords.words('english'))

def preprocess_text(text):
    if text is None:
        return None
    text = text.replace('\n', '')
    return text

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [19]:
# from: https://github.com/NeelShah18/emot/blob/master/emot/emo_unicode.py
EMOTICONS = {
    u":‑\)":"Happy face or smiley",
    u":\)":"Happy face or smiley",
    u":-\]":"Happy face or smiley",
    u":\]":"Happy face or smiley",
    u":-3":"Happy face smiley",
    u":3":"Happy face smiley",
    u":->":"Happy face smiley",
    u":>":"Happy face smiley",
    u"8-\)":"Happy face smiley",
    u":o\)":"Happy face smiley",
    u":-\}":"Happy face smiley",
    u":\}":"Happy face smiley",
    u":-\)":"Happy face smiley",
    u":c\)":"Happy face smiley",
    u":\^\)":"Happy face smiley",
    u"=\]":"Happy face smiley",
    u"=\)":"Happy face smiley",
    u":‑D":"Laughing, big grin or laugh with glasses",
    u":D":"Laughing, big grin or laugh with glasses",
    u"8‑D":"Laughing, big grin or laugh with glasses",
    u"8D":"Laughing, big grin or laugh with glasses",
    u"X‑D":"Laughing, big grin or laugh with glasses",
    u"XD":"Laughing, big grin or laugh with glasses",
    u"=D":"Laughing, big grin or laugh with glasses",
    u"=3":"Laughing, big grin or laugh with glasses",
    u"B\^D":"Laughing, big grin or laugh with glasses",
    u":-\)\)":"Very happy",
    u":‑\(":"Frown, sad, andry or pouting",
    u":-\(":"Frown, sad, andry or pouting",
    u":\(":"Frown, sad, andry or pouting",
    u":‑c":"Frown, sad, andry or pouting",
    u":c":"Frown, sad, andry or pouting",
    u":‑<":"Frown, sad, andry or pouting",
    u":<":"Frown, sad, andry or pouting",
    u":‑\[":"Frown, sad, andry or pouting",
    u":\[":"Frown, sad, andry or pouting",
    u":-\|\|":"Frown, sad, andry or pouting",
    u">:\[":"Frown, sad, andry or pouting",
    u":\{":"Frown, sad, andry or pouting",
    u":@":"Frown, sad, andry or pouting",
    u">:\(":"Frown, sad, andry or pouting",
    u":'‑\(":"Crying",
    u":'\(":"Crying",
    u":'‑\)":"Tears of happiness",
    u":'\)":"Tears of happiness",
    u"D‑':":"Horror",
    u"D:<":"Disgust",
    u"D:":"Sadness",
    u"D8":"Great dismay",
    u"D;":"Great dismay",
    u"D=":"Great dismay",
    u"DX":"Great dismay",
    u":‑O":"Surprise",
    u":O":"Surprise",
    u":‑o":"Surprise",
    u":o":"Surprise",
    u":-0":"Shock",
    u"8‑0":"Yawn",
    u">:O":"Yawn",
    u":-\*":"Kiss",
    u":\*":"Kiss",
    u":X":"Kiss",
    u";‑\)":"Wink or smirk",
    u";\)":"Wink or smirk",
    u"\*-\)":"Wink or smirk",
    u"\*\)":"Wink or smirk",
    u";‑\]":"Wink or smirk",
    u";\]":"Wink or smirk",
    u";\^\)":"Wink or smirk",
    u":‑,":"Wink or smirk",
    u";D":"Wink or smirk",
    u":‑P":"Tongue sticking out, cheeky, playful or blowing a raspberry",
    u":P":"Tongue sticking out, cheeky, playful or blowing a raspberry",
    u"X‑P":"Tongue sticking out, cheeky, playful or blowing a raspberry",
    u"XP":"Tongue sticking out, cheeky, playful or blowing a raspberry",
    u":‑Þ":"Tongue sticking out, cheeky, playful or blowing a raspberry",
    u":Þ":"Tongue sticking out, cheeky, playful or blowing a raspberry",
    u":b":"Tongue sticking out, cheeky, playful or blowing a raspberry",
    u"d:":"Tongue sticking out, cheeky, playful or blowing a raspberry",
    u"=p":"Tongue sticking out, cheeky, playful or blowing a raspberry",
    u">:P":"Tongue sticking out, cheeky, playful or blowing a raspberry",
    u":‑/":"Skeptical, annoyed, undecided, uneasy or hesitant",
    u":/":"Skeptical, annoyed, undecided, uneasy or hesitant",
    u":-[.]":"Skeptical, annoyed, undecided, uneasy or hesitant",
    u">:[(\\\)]":"Skeptical, annoyed, undecided, uneasy or hesitant",
    u">:/":"Skeptical, annoyed, undecided, uneasy or hesitant",
    u":[(\\\)]":"Skeptical, annoyed, undecided, uneasy or hesitant",
    u"=/":"Skeptical, annoyed, undecided, uneasy or hesitant",
    u"=[(\\\)]":"Skeptical, annoyed, undecided, uneasy or hesitant",
    u":L":"Skeptical, annoyed, undecided, uneasy or hesitant",
    u"=L":"Skeptical, annoyed, undecided, uneasy or hesitant",
    u":S":"Skeptical, annoyed, undecided, uneasy or hesitant",
    u":‑\|":"Straight face",
    u":\|":"Straight face",
    u":$":"Embarrassed or blushing",
    u":‑x":"Sealed lips or wearing braces or tongue-tied",
    u":x":"Sealed lips or wearing braces or tongue-tied",
    u":‑#":"Sealed lips or wearing braces or tongue-tied",
    u":#":"Sealed lips or wearing braces or tongue-tied",
    u":‑&":"Sealed lips or wearing braces or tongue-tied",
    u":&":"Sealed lips or wearing braces or tongue-tied",
    u"O:‑\)":"Angel, saint or innocent",
    u"O:\)":"Angel, saint or innocent",
    u"0:‑3":"Angel, saint or innocent",
    u"0:3":"Angel, saint or innocent",
    u"0:‑\)":"Angel, saint or innocent",
    u"0:\)":"Angel, saint or innocent",
    u":‑b":"Tongue sticking out, cheeky, playful or blowing a raspberry",
    u"0;\^\)":"Angel, saint or innocent",
    u">:‑\)":"Evil or devilish",
    u">:\)":"Evil or devilish",
    u"\}:‑\)":"Evil or devilish",
    u"\}:\)":"Evil or devilish",
    u"3:‑\)":"Evil or devilish",
    u"3:\)":"Evil or devilish",
    u">;\)":"Evil or devilish",
    u"\|;‑\)":"Cool",
    u"\|‑O":"Bored",
    u":‑J":"Tongue-in-cheek",
    u"#‑\)":"Party all night",
    u"%‑\)":"Drunk or confused",
    u"%\)":"Drunk or confused",
    u":-###..":"Being sick",
    u":###..":"Being sick",
    u"<:‑\|":"Dump",
    u"\(>_<\)":"Troubled",
    u"\(>_<\)>":"Troubled",
    u"\(';'\)":"Baby",
    u"\(\^\^>``":"Nervous or Embarrassed or Troubled or Shy or Sweat drop",
    u"\(\^_\^;\)":"Nervous or Embarrassed or Troubled or Shy or Sweat drop",
    u"\(-_-;\)":"Nervous or Embarrassed or Troubled or Shy or Sweat drop",
    u"\(~_~;\) \(・\.・;\)":"Nervous or Embarrassed or Troubled or Shy or Sweat drop",
    u"\(-_-\)zzz":"Sleeping",
    u"\(\^_-\)":"Wink",
    u"\(\(\+_\+\)\)":"Confused",
    u"\(\+o\+\)":"Confused",
    u"\(o\|o\)":"Ultraman",
    u"\^_\^":"Joyful",
    u"\(\^_\^\)/":"Joyful",
    u"\(\^O\^\)／":"Joyful",
    u"\(\^o\^\)／":"Joyful",
    u"\(__\)":"Kowtow as a sign of respect, or dogeza for apology",
    u"_\(\._\.\)_":"Kowtow as a sign of respect, or dogeza for apology",
    u"<\(_ _\)>":"Kowtow as a sign of respect, or dogeza for apology",
    u"<m\(__\)m>":"Kowtow as a sign of respect, or dogeza for apology",
    u"m\(__\)m":"Kowtow as a sign of respect, or dogeza for apology",
    u"m\(_ _\)m":"Kowtow as a sign of respect, or dogeza for apology",
    u"\('_'\)":"Sad or Crying",
    u"\(/_;\)":"Sad or Crying",
    u"\(T_T\) \(;_;\)":"Sad or Crying",
    u"\(;_;":"Sad of Crying",
    u"\(;_:\)":"Sad or Crying",
    u"\(;O;\)":"Sad or Crying",
    u"\(:_;\)":"Sad or Crying",
    u"\(ToT\)":"Sad or Crying",
    u";_;":"Sad or Crying",
    u";-;":"Sad or Crying",
    u";n;":"Sad or Crying",
    u";;":"Sad or Crying",
    u"Q\.Q":"Sad or Crying",
    u"T\.T":"Sad or Crying",
    u"QQ":"Sad or Crying",
    u"Q_Q":"Sad or Crying",
    u"\(-\.-\)":"Shame",
    u"\(-_-\)":"Shame",
    u"\(一一\)":"Shame",
    u"\(；一_一\)":"Shame",
    u"\(=_=\)":"Tired",
    u"\(=\^\·\^=\)":"cat",
    u"\(=\^\·\·\^=\)":"cat",
    u"=_\^=	":"cat",
    u"\(\.\.\)":"Looking down",
    u"\(\._\.\)":"Looking down",
    u"\^m\^":"Giggling with hand covering mouth",
    u"\(\・\・?":"Confusion",
    u"\(?_?\)":"Confusion",
    u">\^_\^<":"Normal Laugh",
    u"<\^!\^>":"Normal Laugh",
    u"\^/\^":"Normal Laugh",
    u"\（\*\^_\^\*）" :"Normal Laugh",
    u"\(\^<\^\) \(\^\.\^\)":"Normal Laugh",
    u"\(^\^\)":"Normal Laugh",
    u"\(\^\.\^\)":"Normal Laugh",
    u"\(\^_\^\.\)":"Normal Laugh",
    u"\(\^_\^\)":"Normal Laugh",
    u"\(\^\^\)":"Normal Laugh",
    u"\(\^J\^\)":"Normal Laugh",
    u"\(\*\^\.\^\*\)":"Normal Laugh",
    u"\(\^—\^\）":"Normal Laugh",
    u"\(#\^\.\^#\)":"Normal Laugh",
    u"\（\^—\^\）":"Waving",
    u"\(;_;\)/~~~":"Waving",
    u"\(\^\.\^\)/~~~":"Waving",
    u"\(-_-\)/~~~ \($\·\·\)/~~~":"Waving",
    u"\(T_T\)/~~~":"Waving",
    u"\(ToT\)/~~~":"Waving",
    u"\(\*\^0\^\*\)":"Excited",
    u"\(\*_\*\)":"Amazed",
    u"\(\*_\*;":"Amazed",
    u"\(\+_\+\) \(@_@\)":"Amazed",
    u"\(\*\^\^\)v":"Laughing,Cheerful",
    u"\(\^_\^\)v":"Laughing,Cheerful",
    u"\(\(d[-_-]b\)\)":"Headphones,Listening to music",
    u'\(-"-\)':"Worried",
    u"\(ーー;\)":"Worried",
    u"\(\^0_0\^\)":"Eyeglasses",
    u"\(\＾ｖ\＾\)":"Happy",
    u"\(\＾ｕ\＾\)":"Happy",
    u"\(\^\)o\(\^\)":"Happy",
    u"\(\^O\^\)":"Happy",
    u"\(\^o\^\)":"Happy",
    u"\)\^o\^\(":"Happy",
    u":O o_O":"Surprised",
    u"o_0":"Surprised",
    u"o\.O":"Surpised",
    u"\(o\.o\)":"Surprised",
    u"oO":"Surprised",
    u"\(\*￣m￣\)":"Dissatisfied",
    u"\(‘A`\)":"Snubbed or Deflated"
}

chat_words_str = """
afaik=As far as i know
afk=away from keyboard
asap=as soon as possible
atk=at the keyboard
atm=at the moment
a3=anytime, anywhere, anyplace
bak=back at keyboard
bbl=be back later
bbs=be back soon
bfn=bye for now
b4n=bye for now
brb=be right back
brt=be right there
btw=by the way
b4=before
b4n=bye for now
cu=see you
cul8r=see you later
cya=see you
faq=frequently asked questions
fc=fingers crossed
fwiw=for what it's worth
fyi=for your information
gal=get a life
gg=good game
gn=good night
gmta=great minds think alike
gr8=great!
g9=genius
ic=I see
icq=I seek you (also a chat program)
ilu=ilu: I love you
imho=in my honest/humble opinion
imo=in my opinion
iow=in other words
irl=in real life
kiss=keep it simple, stupid
ldr=long distance relationship
lmao=laugh my a.. off
lol=laughing out loud
ltns=long time no see
l8r=later
mte=my thoughts exactly
m8=mate
nrn=no reply necessary
oic=oh I see
pita=pain in the a..
prt=party
prw=parents are watching
rofl=rolling on the floor laughing
roflol=rolling on the floor laughing out loud
rotflmao=rolling on the floor laughing my a.. off
sk8=skate
stats=your sex and age
asl=age, sex, location
thx=thank you
ttfn=ta-ta for now!
ttyl=talk to you later
u=you
u2=you too
u4e=yours for ever
wb=welcome back
wtf=what the f...
wtg=way to go!
wuf=where are you from?
w8=wait...
7k=sick:-d laugher
"""

In [20]:
#Remove emoticons
def remove_emoticons(text):
    emoticon_pattern = re.compile(u'(' + u'|'.join(k for k in EMOTICONS) + u')')
    return emoticon_pattern.sub(r'', text)

#Remove emojis
def remove_emoji(string):
    emoji_pattern = re.compile("["
                           u"\U0001F600-\U0001F64F"  # emoticons
                           u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                           u"\U0001F680-\U0001F6FF"  # transport & map symbols
                           u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
                           u"\U00002702-\U000027B0"
                           u"\U000024C2-\U0001F251"
                           "]+", flags=re.UNICODE)
    return emoji_pattern.sub(r'', string)

In [21]:
# Chat word conversion
def build_chat_word_mapping(chat_words_str):
    mapping = {}
    cw_set = set()
    for line in chat_words_str.split('\n'):
        if line != '':
            cw, cw_description = line.split('=')[0], line.split('=')[1]
            cw_set.add(cw)
            mapping[cw] = cw_description
    return mapping, cw_set

CW_MAPPING, CW_SET = build_chat_word_mapping(chat_words_str)

def convert_chat_word(text):
    new_text = []
    for word in text.split():
        if word.upper() in CW_SET:
              new_text.append(CW_MAPPING[word.upper()])
        else:
              new_text.append(word)
    return ' '.join(new_text)

In [22]:
#Applying cleaning methods.
tqdm.pandas()
df = remove_empty(df)
df['reddit_text']= df['reddit_text'].progress_apply(remove_url)
df['reddit_text']= df['reddit_text'].progress_apply(preprocess_text)
df['reddit_text']= df['reddit_text'].progress_apply(remove_emoticons)
df['reddit_text']= df['reddit_text'].progress_apply(remove_emoji)
df['reddit_text']= df['reddit_text'].progress_apply(convert_chat_word)
df = remove_empty(df)

100%|██████████| 102876/102876 [00:00<00:00, 221366.55it/s]
100%|██████████| 102876/102876 [00:00<00:00, 604121.56it/s]
100%|██████████| 102876/102876 [00:21<00:00, 4690.43it/s]
100%|██████████| 102876/102876 [00:00<00:00, 121445.13it/s]
100%|██████████| 102876/102876 [00:01<00:00, 71691.10it/s]


In [23]:
df.sample(20)

Unnamed: 0,reddit_text,reddit_subreddit
90533,Yea who cares,BestBuyWorkers
64972,"Remember too (and I was just told this, so if ...",cabincrewcareers
36307,I provided the same answer and also got a CJO....,cabincrewcareers
101006,thank you!,KrakenSupport
10847,Congrats!! Autotechs were always severely unde...,BestBuyWorkers
91739,"Yeah, totally a thing. Customers are already c...",TalesFromYourBank
90222,YAYAYAYA CONGRATS,PaneraEmployees
28527,I appreciate the support! I’m hoping that I’m ...,cabincrewcareers
69093,"Sorry to hear that u/UltimaWizard,We can under...",KrakenSupport
57228,Never a problem with kraken since 2017,KrakenSupport


In [24]:
# Saving cleaned data for potential future use
# df.to_csv('processeddata.csv', index=True)

## 2. Using SBERT to rank the reddit comments

In [25]:
#Defining SBERT model for generating sentence embeddings
sentence_model = SentenceTransformer("thenlper/gte-large")

def get_sentence_embedding(text):
    if not text.strip():
    #.strip() gets rid of new lines
        print("Attempted to get embedding for empty text.")
        return []

    embedding = sentence_model.encode(text)

    return embedding.tolist()

def cos_angle(v, w):
#inputs v, w are vectors in 1-dimensional tensors (np.array(list))
    v = v.reshape(1,-1)
    w = w.reshape(1,-1)
    return cosine_similarity(v, w)

def inn(v, w):
    return v @ w

In [26]:
# loading or generating embeddings
embedding_results = []

for col in tqdm(selected_subreddts, position=1):
    dfcol = df[df['reddit_subreddit']==col]
    dfcol = dfcol.sort_values(by='reddit_text') # sort them by reddit texts
    dfcol = dfcol.reset_index().drop(columns='index') # resetting indices
    if os.path.exists(col+'_embeddings.pt'):
        textembedding = torch.load(col + '_embeddings.pt')
        dfcol = dfcol.assign(textembedding = textembedding)
    else:
        textembedding = [get_sentence_embedding(x) for x in tqdm(dfcol['reddit_text'], position=0)] #compute sentence embeddings
        torch.save(textembedding, col + '_embeddings.pt') #save embeddings for potential use later
        dfcol = dfcol.assign(textembedding = textembedding)
        dfcol.to_csv(col+'.csv', index=True) #save dataframe for potential use later

    d = {
    "subreddit": col,
    "dataframe": dfcol,
    "average_embedding": np.average(textembedding,axis = 0)}

    embedding_results.append(d)

db = pd.DataFrame(embedding_results)


  0%|          | 0/9 [00:00<?, ?it/s][A
 11%|█         | 1/9 [00:05<00:40,  5.00s/it][A
 22%|██▏       | 2/9 [00:09<00:31,  4.57s/it][A
 33%|███▎      | 3/9 [00:12<00:23,  3.98s/it][A
 44%|████▍     | 4/9 [00:15<00:18,  3.61s/it][A
 56%|█████▌    | 5/9 [00:17<00:12,  3.04s/it][A
 67%|██████▋   | 6/9 [00:18<00:07,  2.36s/it][A
 78%|███████▊  | 7/9 [00:19<00:03,  1.80s/it][A
100%|██████████| 9/9 [00:19<00:00,  2.20s/it]


In [27]:
db.head(10)

Unnamed: 0,subreddit,dataframe,average_embedding
0,TalesFromYourBank,r...,"[-0.0051973616468823395, -0.00625029310158874,..."
1,cabincrewcareers,r...,"[0.0004317182095134985, -0.00513778408053003, ..."
2,Chase,r...,"[-0.011372712865389696, -0.0010401242595671806..."
3,KrakenSupport,r...,"[-0.0017265008515737084, -0.001783644876558678..."
4,WalmartEmployees,r...,"[-0.008500096951390947, -0.0024834221175715956..."
5,BestBuyWorkers,re...,"[-0.00587379905436581, -0.005191515158705407, ..."
6,RiteAid,re...,"[-0.005485486050302116, -0.0022910498164752857..."
7,PaneraEmployees,re...,"[-0.006839219016928381, -0.002444414936333884,..."
8,FedEmployees,red...,"[-0.00456042588506361, -0.006909223673574161, ..."


In [28]:
clusterered_embeddings = []

num_clusters = 4
clustering_model = KMeans(n_clusters=num_clusters, random_state=0)

for col in tqdm(selected_subreddts):
    dfcol = df[df['reddit_subreddit']==col]
    dfcol = dfcol.sort_values(by='reddit_text') #sort them by reddit texts
    dfcol = dfcol.reset_index().drop(columns='index') #resetting indices

    textembedding = torch.load(col + '_embeddings.pt')
    dfcol = dfcol.assign(textembedding = textembedding) #create 'textembedding' column and store textembedding in it

    X = preprocessing.normalize(textembedding)
    clustering_model.fit(X)
    dfcol['Cluster ID'] = pd.Series(clustering_model.fit_predict(X))

    for i in range(num_clusters):
        temp = dfcol[dfcol['Cluster ID'] == i]
        d = {
        "subreddit": col + '-Cluster' + str(i),
        "dataframe": temp,
        "average_embedding": np.average([x for x in temp['textembedding']], axis = 0)}
        clusterered_embeddings.append(d)


db_cluster = pd.DataFrame(clusterered_embeddings)

100%|██████████| 9/9 [02:32<00:00, 17.00s/it]


In [29]:
db_cluster.head(10)

Unnamed: 0,subreddit,dataframe,average_embedding
0,TalesFromYourBank-Cluster0,r...,"[-0.006795902906831044, -0.005613459694692606,..."
1,TalesFromYourBank-Cluster1,r...,"[-0.003341804241755992, -0.007956792195795129,..."
2,TalesFromYourBank-Cluster2,r...,"[-0.005708994473888968, -0.009339751483220488,..."
3,TalesFromYourBank-Cluster3,r...,"[-0.00402720857133523, -0.0032547114839317163,..."
4,cabincrewcareers-Cluster0,r...,"[-0.0022175819854195877, -0.003243233466653384..."
5,cabincrewcareers-Cluster1,r...,"[0.008595544593372435, -0.00520953808124305, -..."
6,cabincrewcareers-Cluster2,r...,"[-0.008854166506720093, 0.0009743460049588485,..."
7,cabincrewcareers-Cluster3,r...,"[-0.0010806543073500457, -0.010234497415531153..."
8,Chase-Cluster0,r...,"[-0.010093581000317087, 0.005075598923411346, ..."
9,Chase-Cluster1,r...,"[-0.014370608862751497, -0.0035616086849616295..."


In [30]:
print("Total number of clusters: " + str(len(db_cluster)))

Total number of clusters: 36


In [31]:
db_cluster.to_csv('clustereddata.csv')

In [32]:
query = "How many PTOs does a regular employee have a year?"
query_vector = get_sentence_embedding(query)
query_vector = np.array(query_vector)

In [33]:
# Assigning average embedding and finding the closest subreddit -- create new col 'query_to_average' and then sort by the new column
db = db.assign(query_to_average = [cos_angle(query_vector, x)[0][0] for x in db['average_embedding']]).sort_values('query_to_average', ascending = False)
db = db.reset_index().drop(columns='index') #resetting indices
db.head(10)

Unnamed: 0,subreddit,dataframe,average_embedding,query_to_average
0,FedEmployees,red...,"[-0.00456042588506361, -0.006909223673574161, ...",0.869348
1,WalmartEmployees,r...,"[-0.008500096951390947, -0.0024834221175715956...",0.863571
2,BestBuyWorkers,re...,"[-0.00587379905436581, -0.005191515158705407, ...",0.860327
3,RiteAid,re...,"[-0.005485486050302116, -0.0022910498164752857...",0.857882
4,PaneraEmployees,re...,"[-0.006839219016928381, -0.002444414936333884,...",0.853539
5,cabincrewcareers,r...,"[0.0004317182095134985, -0.00513778408053003, ...",0.848919
6,TalesFromYourBank,r...,"[-0.0051973616468823395, -0.00625029310158874,...",0.842512
7,Chase,r...,"[-0.011372712865389696, -0.0010401242595671806...",0.820753
8,KrakenSupport,r...,"[-0.0017265008515737084, -0.001783644876558678...",0.80038


In [34]:
# Assigning average embedding to each cluster and finding the closest clusters
db_cluster = db_cluster.assign(query_to_average = [cos_angle(query_vector, x)[0][0] for x in db_cluster['average_embedding']]).sort_values('query_to_average', ascending = False)
db_cluster = db_cluster.reset_index().drop(columns='index') #resetting indices
db_cluster.head(10)

Unnamed: 0,subreddit,dataframe,average_embedding,query_to_average
0,BestBuyWorkers-Cluster1,re...,"[-0.000598085417866919, -0.00867528970097949, ...",0.877849
1,RiteAid-Cluster3,re...,"[-0.00351916891331693, -0.00755134534602435, -...",0.873159
2,WalmartEmployees-Cluster1,r...,"[-0.00618920366138775, -0.0016835359608260514,...",0.871133
3,PaneraEmployees-Cluster1,re...,"[-0.0024286220564592093, -0.004583379318224034...",0.866717
4,FedEmployees-Cluster1,red...,"[-0.0006124643553953825, -0.008996817065376996...",0.865469
5,WalmartEmployees-Cluster0,r...,"[-0.008436679147904334, -0.007346286229310813,...",0.864558
6,FedEmployees-Cluster0,red...,"[-0.004843487278226665, -0.016834775137520113,...",0.862534
7,FedEmployees-Cluster2,red...,"[-0.007335492723116962, -0.0042676048333911846...",0.851879
8,TalesFromYourBank-Cluster1,r...,"[-0.003341804241755992, -0.007956792195795129,...",0.849788
9,cabincrewcareers-Cluster0,r...,"[-0.0022175819854195877, -0.003243233466653384...",0.847042


In [35]:
df_bestmatch_nocluster = pd.concat([db.iloc[i].get('dataframe') for i in range(3)]) #top 3 subreddits
df_bestmatch_nocluster

Unnamed: 0,reddit_text,reddit_subreddit,textembedding
0,**Personal Investment Performance (PIP** - The...,FedEmployees,"[-0.023384714499115944, -0.0011062192497774959..."
1,**[More Lifecycle Funds will be available July...,FedEmployees,"[0.009730224497616291, 0.01290331780910492, -0..."
2,"/u/CH1EFK1NGD0M, I have found an error in your...",FedEmployees,"[-0.001974150538444519, -0.029495904222130775,..."
3,1yr from date you started the position in that...,FedEmployees,"[-0.011592132970690727, 0.018938768655061722, ..."
4,21% gains,FedEmployees,"[-0.00389385549351573, 0.02789214625954628, -0..."
...,...,...,...
5368,“Who are you to deny the return of the PS5 box...,BestBuyWorkers,"[-0.008802000433206558, -0.010793417692184448,..."
5369,“the narrative” Lmao or it’s just likely that ...,BestBuyWorkers,"[-0.008601265959441662, -0.02134491316974163, ..."
5370,… yes? Are you one of the programmers did I hu...,BestBuyWorkers,"[0.02130362205207348, -0.017897509038448334, -..."
5371,🤣🤣 thank you for the advice. Mind elaborating ...,BestBuyWorkers,"[-0.0002998623822350055, 0.0027278345078229904..."


In [36]:
df_bestmatch_cluster = pd.concat([db_cluster.iloc[i].get('dataframe') for i in range(3)]) #top 3 clusters
df_bestmatch_cluster

Unnamed: 0,reddit_text,reddit_subreddit,textembedding,Cluster ID
16,"""Anonymous"", sorry to give you the bad news bu...",BestBuyWorkers,"[-0.00750601664185524, -0.014071100391447544, ...",1
19,"""Hey guys, I paid for the pizza out of my own ...",BestBuyWorkers,"[0.008565056137740612, -0.010689752176404, -0....",1
23,"""Steer clear of post-military employees, speci...",BestBuyWorkers,"[0.003716543084010482, -0.009089987725019455, ...",1
27,"""new structure"" has increased the number of ca...",BestBuyWorkers,"[0.009708089753985405, -0.013909395784139633, ...",1
36,(Happy Cake Day!This is the same issue that ou...,BestBuyWorkers,"[0.018204687163233757, -0.0037236677017062902,...",1
...,...,...,...,...
10213,you wont be fired. if your HR person is there ...,WalmartEmployees,"[-0.022134169936180115, -0.013656818307936192,...",1
10215,you're good. you dont even have to follow what...,WalmartEmployees,"[-0.0022417991422116756, 0.014024728909134865,...",1
10222,"y’all i wish lmfao. however, how long have you...",WalmartEmployees,"[-0.003708786331117153, 0.006598129868507385, ...",1
10224,Ìf you don't go in next shift you're definitel...,WalmartEmployees,"[-0.007710224948823452, -0.01927049458026886, ...",1


In [37]:
# Without clustering, the runtime is over 1s
df_bestmatch_nocluster = df_bestmatch_nocluster.assign(cos_angles = [cos_angle(query_vector, np.array(x))[0][0] for x in df_bestmatch_nocluster['textembedding']]).sort_values('cos_angles', ascending = False)

In [38]:
df_bestmatch_nocluster.head(20)

Unnamed: 0,reddit_text,reddit_subreddit,textembedding,cos_angles
5678,"Not sure about PTO, but the most ppto you can ...",WalmartEmployees,"[-0.017012735828757286, 0.0018752547912299633,...",0.881989
301,"All associates earn PPTO, which is intended fo...",WalmartEmployees,"[-0.022059407085180283, -0.001781671424396336,...",0.879523
1619,Full timers generally earn pto faster. But mos...,WalmartEmployees,"[-0.008347813971340656, 0.010986113920807838, ...",0.879252
4867,Just went through onboarding (for Walmart dist...,WalmartEmployees,"[-0.0019015477737411857, -0.009029118344187737...",0.87848
2620,I get 21 mins of ppto earned per shift (10 hrs...,WalmartEmployees,"[0.0035366499796509743, 0.007027538493275642, ...",0.877869
2275,It is like pto,BestBuyWorkers,"[-0.01784772425889969, -0.026210324838757515, ...",0.877719
8914,Yeah I'm aware of the benefits and pto. I just...,WalmartEmployees,"[0.004729976411908865, -0.006030676886439323, ...",0.87631
2811,Most places I've seen typically provide 0 PTO ...,BestBuyWorkers,"[-0.0130627965554595, -0.01606047712266445, -0...",0.875701
5655,Not everyone gets ppto at the same rate. You s...,WalmartEmployees,"[-0.008390869945287704, 0.013110107742249966, ...",0.874638
10041,"overtime, bonuses, PTO",WalmartEmployees,"[-0.003883741097524762, -0.011234667152166367,...",0.873787


In [39]:
# By clustering, we cut down our runtown to under 1s
df_bestmatch_cluster = df_bestmatch_cluster.assign(cos_angles = [cos_angle(query_vector, np.array(x))[0][0] for x in df_bestmatch_cluster['textembedding']]).sort_values('cos_angles', ascending = False)

In [40]:
df_bestmatch_cluster.head(20)

Unnamed: 0,reddit_text,reddit_subreddit,textembedding,Cluster ID,cos_angles
5678,"Not sure about PTO, but the most ppto you can ...",WalmartEmployees,"[-0.017012735828757286, 0.0018752547912299633,...",1,0.881989
301,"All associates earn PPTO, which is intended fo...",WalmartEmployees,"[-0.022059407085180283, -0.001781671424396336,...",1,0.879523
1619,Full timers generally earn pto faster. But mos...,WalmartEmployees,"[-0.008347813971340656, 0.010986113920807838, ...",1,0.879252
4867,Just went through onboarding (for Walmart dist...,WalmartEmployees,"[-0.0019015477737411857, -0.009029118344187737...",1,0.87848
2620,I get 21 mins of ppto earned per shift (10 hrs...,WalmartEmployees,"[0.0035366499796509743, 0.007027538493275642, ...",1,0.877869
2275,It is like pto,BestBuyWorkers,"[-0.01784772425889969, -0.026210324838757515, ...",1,0.877719
8914,Yeah I'm aware of the benefits and pto. I just...,WalmartEmployees,"[0.004729976411908865, -0.006030676886439323, ...",1,0.87631
2811,Most places I've seen typically provide 0 PTO ...,BestBuyWorkers,"[-0.0130627965554595, -0.01606047712266445, -0...",1,0.875701
5655,Not everyone gets ppto at the same rate. You s...,WalmartEmployees,"[-0.008390869945287704, 0.013110107742249966, ...",1,0.874638
10041,"overtime, bonuses, PTO",WalmartEmployees,"[-0.003883741097524762, -0.011234667152166367,...",1,0.873787


## 3. Improving performance via synthetic query generation and averaging


In [41]:
# We generate synethetic queries to help improve the performance of our model
input = tokenizer.encode(query, return_tensors='pt')

outputs = model.generate(
    input_ids=input,
    max_length=16,
    do_sample=True,
    top_p=0.95,
    num_return_sequences=3)

print("Original query:")

print(query)

print("\nGenerated Synthetic Queries:")

for i in range(len(outputs)):
    q = tokenizer.decode(outputs[i], skip_special_tokens=True)
    print(f'{i + 1}: {q}')

Original query:
How many PTOs does a regular employee have a year?

Generated Synthetic Queries:
1: how many pto does an employee get
2: how many pto hours does an employee get?
3: how many ptos do employees have


In [42]:
chosen_query_vectors = [get_sentence_embedding(tokenizer.decode(x, skip_special_tokens=True)) for x in outputs]
chosen_query_vectors.append(query_vector)

In [43]:
avg_cos_angles = []

for x in df_bestmatch_cluster['textembedding']:
    avg_cos_angle = 0
    for v in chosen_query_vectors:
        avg_cos_angle += cos_angle(np.array(v), np.array(x))[0][0]
    avg_cos_angles.append(avg_cos_angle/len(chosen_query_vectors))


In [44]:
df_bestmatch_cluster = df_bestmatch_cluster.assign(avg_cos_angles = avg_cos_angles).sort_values('cos_angles', ascending = False)
df_bestmatch_by_avg = df_bestmatch_cluster.assign(avg_cos_angles = avg_cos_angles).sort_values('avg_cos_angles', ascending = False)

In [45]:
df_bestmatch_cluster.head(20)

Unnamed: 0,reddit_text,reddit_subreddit,textembedding,Cluster ID,cos_angles,avg_cos_angles
5678,"Not sure about PTO, but the most ppto you can ...",WalmartEmployees,"[-0.017012735828757286, 0.0018752547912299633,...",1,0.881989,0.895375
301,"All associates earn PPTO, which is intended fo...",WalmartEmployees,"[-0.022059407085180283, -0.001781671424396336,...",1,0.879523,0.886344
1619,Full timers generally earn pto faster. But mos...,WalmartEmployees,"[-0.008347813971340656, 0.010986113920807838, ...",1,0.879252,0.888278
4867,Just went through onboarding (for Walmart dist...,WalmartEmployees,"[-0.0019015477737411857, -0.009029118344187737...",1,0.87848,0.881819
2620,I get 21 mins of ppto earned per shift (10 hrs...,WalmartEmployees,"[0.0035366499796509743, 0.007027538493275642, ...",1,0.877869,0.889154
2275,It is like pto,BestBuyWorkers,"[-0.01784772425889969, -0.026210324838757515, ...",1,0.877719,0.881516
8914,Yeah I'm aware of the benefits and pto. I just...,WalmartEmployees,"[0.004729976411908865, -0.006030676886439323, ...",1,0.87631,0.879444
2811,Most places I've seen typically provide 0 PTO ...,BestBuyWorkers,"[-0.0130627965554595, -0.01606047712266445, -0...",1,0.875701,0.88036
5655,Not everyone gets ppto at the same rate. You s...,WalmartEmployees,"[-0.008390869945287704, 0.013110107742249966, ...",1,0.874638,0.888211
10041,"overtime, bonuses, PTO",WalmartEmployees,"[-0.003883741097524762, -0.011234667152166367,...",1,0.873787,0.88864


In [46]:
df_bestmatch_by_avg.head(20)

Unnamed: 0,reddit_text,reddit_subreddit,textembedding,Cluster ID,cos_angles,avg_cos_angles
5678,"Not sure about PTO, but the most ppto you can ...",WalmartEmployees,"[-0.017012735828757286, 0.0018752547912299633,...",1,0.881989,0.895375
2620,I get 21 mins of ppto earned per shift (10 hrs...,WalmartEmployees,"[0.0035366499796509743, 0.007027538493275642, ...",1,0.877869,0.889154
10041,"overtime, bonuses, PTO",WalmartEmployees,"[-0.003883741097524762, -0.011234667152166367,...",1,0.873787,0.88864
1619,Full timers generally earn pto faster. But mos...,WalmartEmployees,"[-0.008347813971340656, 0.010986113920807838, ...",1,0.879252,0.888278
5655,Not everyone gets ppto at the same rate. You s...,WalmartEmployees,"[-0.008390869945287704, 0.013110107742249966, ...",1,0.874638,0.888211
301,"All associates earn PPTO, which is intended fo...",WalmartEmployees,"[-0.022059407085180283, -0.001781671424396336,...",1,0.879523,0.886344
1517,It varies greatly by state/location and store ...,RiteAid,"[0.0023047656286507845, 0.0036134396214038134,...",3,0.866713,0.883882
4867,Just went through onboarding (for Walmart dist...,WalmartEmployees,"[-0.0019015477737411857, -0.009029118344187737...",1,0.87848,0.881819
2275,It is like pto,BestBuyWorkers,"[-0.01784772425889969, -0.026210324838757515, ...",1,0.877719,0.881516
2227,"How the hell do they Calculate PTO? Cause, at ...",WalmartEmployees,"[-0.0203511044383049, -0.006431135348975658, -...",1,0.872277,0.881054


## 4. Feeding the ranked and re-ranked sentences to LLM

We introduce a function to generate a prompt for our LLM from top $k$ comments from ranked (or re-ranked comments):

In [47]:
def rag_prompt(df_col, k):
    information_to_feed = ""
    for n, text in zip(range(k), df_col):
        information_to_feed += f"{n+1}: " + text + "\n"
    # concatenate the first top k comments
    combined_information = f"\nQuery: {query}\n\nAnswer the above query by only using the following:\n\n{information_to_feed}\n\nLLM Response:"

    return combined_information

In [48]:
prompt_ranked = rag_prompt(df_bestmatch_cluster['reddit_text'], 5) # prompt generated from top 5 commenets ranked by cos sim to the original query
prompt_reranked = rag_prompt(df_bestmatch_by_avg['reddit_text'], 5) # prompt generated from top 5 commenets reranked by cos sim to the original query + alternative queries

**Warning**. To reproduce the following, one may need HuggingFace API Key (which is free for the purpose of this notebook).

In [49]:
access_token = "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"

In [60]:
auto_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it" , device_map="auto", token=access_token)
# CPU Enabled uncomment below
# model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")
# GPU Enabled use below
auto_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", device_map="auto", token=access_token)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



We write a function that produces a response from our LLM given a prompt.

In [61]:
def llm(prompt):
    input_ids = auto_tokenizer(prompt, return_tensors="pt")
    response = auto_model.generate(**input_ids, max_new_tokens=512)
    return auto_tokenizer.decode(response[0])

The following takes a while, but we believe that this can be improved by using a higher-performing LLM:

In [62]:
llm_answer1 = llm(prompt_ranked)
llm_answer2 = llm(prompt_reranked)

In [63]:
print("Using top 5 commenets ranked by cos sim to the original query \n\n", llm_answer1)

Using top 5 commenets ranked by cos sim to the original query 

 <bos>
Query: How many PTOs does a regular employee have a year?

Answer the above query by only using the following:

1: Not sure about PTO, but the most ppto you can earn in a year is 48 hours, with the exception of a couple of states which are unlimited by state law.
2: All associates earn PPTO, which is intended for emergencies (car won't start, you're sick, etcFull-time associates also earn PTO, which is intended for scheduled absences (doctor appointments, getting your drivers license renewed, vacations, etc. **Part-time associates do not earn PTO until they have been with the company for 3 years,** because part-time associates should be able to schedule their appointments and other errands on their days off.
3: Full timers generally earn pto faster. But most part timers actually earn ppto faster. For most associates in most states, ppto is 1 hour for every 30 hours worked. In most states the ppto accrual rate doesn’

In [64]:
print("Using top 5 commenets reranked by cos sim to the original query + alternative queries \n\n", llm_answer2)

Using top 5 commenets reranked by cos sim to the original query + alternative queries 

 <bos>
Query: How many PTOs does a regular employee have a year?

Answer the above query by only using the following:

1: Not sure about PTO, but the most ppto you can earn in a year is 48 hours, with the exception of a couple of states which are unlimited by state law.
2: I get 21 mins of ppto earned per shift (10 hrs and 19 mins of pto ... some states have different rules/laws but so I earn 84 minutes of ppto a week and 76 mins of regular
3: overtime, bonuses, PTO
4: Full timers generally earn pto faster. But most part timers actually earn ppto faster. For most associates in most states, ppto is 1 hour for every 30 hours worked. In most states the ppto accrual rate doesn’t change over time (aside from with part timers, who get an increase at their 3rd year
5: Not everyone gets ppto at the same rate. You start earning from day one and it’s available to use after 90 days. If you’re part time and les

## 5. Evaluation of retrieval

Note that it is rather difficult to say which LLM responses are better. Moreover, we note that our goal is NOT to get the answer that is absolutely correct but a relevant one among the reddit comments that we put in. For example, the answer may change over time, unless we update the input comments.

Hence, we use use both of the LLM responses as ground truths and compare the top 50 retrievals from the two methods:
* Method 1: Naive RAG using cosine similairties against the original query
* Method 2: Not-so-naive RAG using average cosine similairties against multiple similar queries, including the original one

The following function separates the LLM response from an LLM answer:

In [65]:
def llm_answer_to_response(answer):
    start = answer.find("LLM Response:")
    end = answer.find("<eos>")
    return answer[start+len("LLM Response:")+1:end]

print(llm_answer_to_response(llm_answer1))
print(llm_answer_to_response(llm_answer2))

Regular employees typically earn 48 hours of paid time off per year, with the exception of a few states that have unlimited PTO by state law.
According to the passage, regular employees are entitled to 1 hour of paid time off per 30 hours worked.


In [66]:
truth_1 = llm_answer_to_response(llm_answer1)
truth_2 = llm_answer_to_response(llm_answer2)

In [67]:
vectors_1 = []
vectors_2 = []

for i in range(50):
    vectors_1.append(df_bestmatch_cluster.get('textembedding').iloc[i])
    vectors_2.append(df_bestmatch_by_avg.get('textembedding').iloc[i])

### Evaluation metric 1: cosine precision
The following is a function with which we evaluate the retrieval from each method. Let $\boldsymbol{t}_1$ and $\boldsymbol{t}_2$ be the truth vectors. For each vector $\boldsymbol{v}$ from a batch, the cosine similarities $\cos(\boldsymbol{t}_1, \boldsymbol{v})$ and $\cos(\boldsymbol{t}_2, \boldsymbol{v})$ are in the interval $[-1, 1]$, but in all of our examples, we know they are in $[0, 1]$. We simply take the average of the two to measure how truthful $\boldsymbol{v}$ is. Note that the closer the average is to $1$, the more truthful $\boldsymbol{v}$ is.

Recall the definition of **precision**:
$$\mathrm{Precision} := \frac{\mathrm{Relevant \ retrieved \ instances}}{\mathrm{All \ retrieved \ instances}}.$$

Given a batch $B$, we define the **cosine precision** as follows:

$$\mathrm{Cosine \ Precision \ of } \ B := \frac{1}{2|B|}\sum_{\boldsymbol{v} \in B}  (\cos(\boldsymbol{t}_1, \boldsymbol{v}) + \cos(\boldsymbol{t}_2, \boldsymbol{v}))$$

In [68]:
def cos_precision(batch, t_1, t_2):
    t_1 = np.array(t_1)
    t_2 = np.array(t_2)

    sum = 0

    for v in batch:
        v = np.array(v)
        sum += (cos_angle(t_1, v) + cos_angle(t_2, v))
    return sum / (2*len(batch))

t_1 = get_sentence_embedding(truth_1)
t_2 = get_sentence_embedding(truth_2)

In [69]:
cos_precision(vectors_1, t_1, t_2)

array([[0.84906209]])

In [70]:
cos_precision(vectors_2, t_1, t_2)

array([[0.85509762]])

### Evaluation metric 2: ranked cosine precision

The following is a function that evaluates not only the retrieval, but also evaluates the ranking for the retrieved contexts.

Assume we retrieved $K$ comments in the context, ranked as $B = (x_1, \ldots, x_K)$.

We call the **precision at rank $m$** the cosine precision for the truncated context $B_m := (x_1, \ldots, x_m)$. And the ranked cosine precision is the average of these precisions.

$$
\text{Ranked Cosine Precision of } B := \frac{1}{K} \sum_{m = 1}^{K} \text{Cosine Precision of } B_m.
$$

Under this measurement, those comments ranked higher in the retrieved context will have a higher impact to the precision.

In [71]:
def cos_rank_precision(batch, t_1, t_2):
    sum = 0

    for m in range(1, len(batch)+1):
        sum += cos_precision(batch[:m], t_1, t_2)

    return sum / len(batch)

In [72]:
cos_rank_precision(vectors_1, t_1, t_2)

array([[0.85837544]])

In [73]:
cos_rank_precision(vectors_2, t_1, t_2)

array([[0.86432416]])

## 6. Conclusion

As we have seen in the example above, our averaging method improves the overall retrieval better by getting rid of possibly unrelated retrieved data by comparisions with multiple similar queries to the original one. The LLM API took some time, but it is evident that any stronger LLM we use would only make this process faster.