# Self-Guessing Hybrid Search (aka Part 3/3)

Assume we have a working recipe for hybrid search in the form `(query, [keyword]) => [(snippet, similarity)]`.

The goal of this notebook is to answer the question: _can we self-guess the keywords from the query, so that the caller supplies just the query?_

In [1]:
import os
from functools import partial
from dotenv import load_dotenv

_ = load_dotenv('.env')

In [2]:
import cassio

In [3]:
cassio.init(
    token=os.environ['ASTRA_DB_APPLICATION_TOKEN'],
    database_id=os.environ['ASTRA_DB_ID'],
    keyspace=os.environ.get('ASTRA_DB_KEYSPACE'),
)
session = cassio.config.resolve_session()
keyspace = cassio.config.resolve_keyspace()

In [4]:
import openai

embedding_model_name = "text-embedding-ada-002"

def get_embeddings(texts):
    result = openai.Embedding.create(
        input=texts,
        engine=embedding_model_name,
    )
    return [res.embedding for res in result.data]

### Recap

Let's use the latest "hybrid search" function from the previous investigation. One important point is that we'll combine the keywords in OR (i.e. a single keyword match suffices for a hit). This lets the keyword side make a meaningful contribution to the "score"; more importantly, since the keywords are self-guessed, it offers some protection against overly demanding "guesses" from the query.

**We have packaged the latest machinery from Part 2 into a Python module to reduce clutter; there is nothing new in it.**

_(We just renamed the final hybrid-search function for convenience and made sure the DB parameters pass through the calls)_

In [5]:
from kw_hybrid_tools import hybrid_search_with_kw, show, keyword_similarity, sum_score_merger

That has the following signature:
```
def hybrid_search_with_kw(session, keyspace, get_embeddings, query, keywords=[],
                          top_k=3, kw_similarity_function=keyword_similarity,
                          score_merger_function=sum_score_merger, prefetch_factor=5):
    ...
```
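
To make the OR semantics and the score merging more concrete, here is a toy sketch. It is not the code in `kw_hybrid_tools`: the names `toy_keyword_similarity` and `toy_score_merger`, the fraction-of-keywords scoring and the averaging are all assumptions made for illustration only.

```python
import re

def toy_keyword_similarity(snippet, keywords):
    """Hypothetical keyword score: fraction of the keywords found among the snippet's words."""
    if not keywords:
        return 0.0
    words = set(re.findall(r"[a-z0-9]+", snippet.lower()))
    matched = sum(1 for kw in keywords if kw.lower() in words)
    return matched / len(keywords)

def toy_score_merger(ann_similarity, kw_similarity):
    """Hypothetical merger: plain average of the vector score and the keyword score."""
    return (ann_similarity + kw_similarity) / 2

kws = ["support", "chat"]
for snippet in ["I cannot open the support chat.", "A message disappeared from the chat?"]:
    print(snippet, toy_keyword_similarity(snippet, kws))
```

With keywords in OR, a snippet matching only one of the two keywords still receives a partial keyword score instead of being excluded from the results.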

Now, let's define a handy shortcut and run a sanity check (to be compared with the "QUERY7/KW7" run in the previous notebook):

In [6]:
hybrid_kw = partial(hybrid_search_with_kw, session=session, keyspace=keyspace, get_embeddings=get_embeddings)

KW0 = ['support', 'chat']
QUERY0 = "How come I cannot chat?"

print(f"[with safe prefetch] QUERY: '{QUERY0}', KEYWORDS: \'{', '.join(KW0)}\'")
show(hybrid_kw(query=QUERY0, keywords=KW0))

[with safe prefetch] QUERY: 'How come I cannot chat?', KEYWORDS: 'support, chat'
    [1] 0.96762 "I cannot open the support chat."
    [2] 0.96157 "I see no messages in the support chat."
    [3] 0.95499 "The support chat on the website is lagging."


## Guessing the keywords from the query

Let us first consider a simple, and rather disappointing, keyword-guesser function (several more will follow later):

In [7]:
PUNKT = set('!,.?;\'"-+=/[]{}()\n')

def guess_kws_simple(query):
    # Strip punctuation, lowercase, and keep only the words longer than 4 characters
    _qry = ''.join([c for c in query if c not in PUNKT]).lower()
    return {
        w
        for w in _qry.split(' ')
        if len(w) > 4
    }

Rather crude, isn't it?

In [8]:
print(guess_kws_simple("The report due today is on Mia's desk, Benjamin!"))

{'report', 'today', 'benjamin'}


Let's start from this and repackage a keyword-guessing hybrid search function (again, we take advantage of the partialed shortcut to focus on the important bits):

In [9]:
def hybrid_guess(query, kw_guesser, top_k=3, kw_similarity_function=keyword_similarity,
                 score_merger_function=sum_score_merger, prefetch_factor=5):
    keywords = kw_guesser(query)
    return hybrid_kw(
        query=query,
        keywords=keywords,
        top_k=top_k,
        kw_similarity_function=kw_similarity_function,
        score_merger_function=score_merger_function,
        prefetch_factor=prefetch_factor,
    )

A quick test with the crude guesser (leaving the other settings at their defaults for now):

In [10]:
QUERY1 = "How come I cannot chat?"
print(f"QUERY: '{QUERY1}' [keywords={guess_kws_simple(QUERY1)}]")
show(hybrid_guess(QUERY1, kw_guesser=guess_kws_simple))

QUERY: 'How come I cannot chat?' [keywords={'cannot'}]
    [1] 0.96762 "I cannot open the support chat."
    [2] 0.95840 "I cannot speak with the support operator!"


But clearly this is not the best solution in general:

In [11]:
QUERY2 = "Do you currently have any offers?"
print(f"QUERY: '{QUERY2}' [keywords={guess_kws_simple(QUERY2)}]")
show(hybrid_guess(QUERY2, kw_guesser=guess_kws_simple))

QUERY3 = "Why does the site experience these lags?"
print(f"QUERY: '{QUERY3}' [keywords={guess_kws_simple(QUERY3)}]")
show(hybrid_guess(QUERY3, kw_guesser=guess_kws_simple))

QUERY: 'Do you currently have any offers?' [keywords={'currently', 'offers'}]
QUERY: 'Why does the site experience these lags?' [keywords={'experience', 'these'}]


_Note: with just "Why does the site lag?" you **would** get some results, of course: no keywords are found, so the search falls back to ANN-only._

### Known keyword set, quick approaches

Suppose your knowledge of the problem domain lets you make an explicit list of the keywords you want to potentially use:

> You may achieve this "by hand", or by passing a random sample subset of the snippets to an LLM, ... or a combination of both.

In [12]:
available_keywords = set("buy gift discounts support operator chat message offer lag product payment process shop cart".split(" "))

With such a list, one can build very "cheap" keyword extractors that are fast and, in some cases, effective.

_Note: sometimes these might yield false positives, such as "lag" being a substring of "flagged". The semantic side of the search would mostly take care of these, with the proper tuning in the score-merging phase._

In [13]:
def guess_kws_substr_from_set(query, kws=available_keywords):
    _qry = query.lower().strip()
    # Use the `kws` argument (not the global) so that a custom keyword set is honored
    return {kw for kw in kws if kw in _qry}

def guess_kws_tokens_from_set(query, kws=available_keywords):
    _qry = ''.join([c for c in query if c not in PUNKT]).lower()
    toks = {tk for tk in _qry.split(' ') if tk}
    return toks & kws

In [14]:
queries = [QUERY1, QUERY2, QUERY3]
guessers = [('SUBSTR, from set', guess_kws_substr_from_set), ('TOKENS, from set', guess_kws_tokens_from_set)]

for kw_g_name, kw_g in guessers:
    for qry in queries:
        print(f"\nKwGuesser=<{kw_g_name}>, QUERY='{qry}' [keywords={kw_g(qry)}]")
        show(hybrid_guess(qry, kw_guesser=kw_g))


KwGuesser=<SUBSTR, from set>, QUERY='How come I cannot chat?' [keywords={'chat'}]
    [1] 0.96762 "I cannot open the support chat."
    [2] 0.96608 "A message disappeared from the chat?"
    [3] 0.96157 "I see no messages in the support chat."

KwGuesser=<SUBSTR, from set>, QUERY='Do you currently have any offers?' [keywords={'offer'}]
    [1] 0.96833 "Is there any special offer today?"
    [2] 0.47047 "Are special offers available?"

KwGuesser=<SUBSTR, from set>, QUERY='Why does the site experience these lags?' [keywords={'lag'}]
    [1] 0.46385 "The support chat on the website is lagging."

KwGuesser=<TOKENS, from set>, QUERY='How come I cannot chat?' [keywords={'chat'}]
    [1] 0.96760 "I cannot open the support chat."
    [2] 0.96606 "A message disappeared from the chat?"
    [3] 0.96155 "I see no messages in the support chat."

KwGuesser=<TOKENS, from set>, QUERY='Do you currently have any offers?' [keywords=set()]
    [1] 0.97047 "Are special offers available?"
    [2] 0.96833 "

The substring mode seems to fare better: we get `"lag"` (which was too short in the crude approach) and we don't get confused by `"cannot"` or similar irrelevant words. As remarked earlier, however, we may get false positives (especially if the store holds no better-matching vectors that could climb to the top results and displace the intruders):

In [15]:
QUERY4 = "I have been flagged by an admin... what do I do?"
print(f"\nKwGuesser=<SUBSTR, from set>, QUERY='{QUERY4}' [keywords={guess_kws_substr_from_set(QUERY4)}]")
show(hybrid_guess(QUERY4, kw_guesser=guess_kws_substr_from_set))


KwGuesser=<SUBSTR, from set>, QUERY='I have been flagged by an admin... what do I do?' [keywords={'lag'}]
    [1] 0.43553 "The support chat on the website is lagging."
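
One possible middle ground between the substring and token modes is to accept a keyword only when it appears at the start of a word. The sketch below is an assumption, not part of the notebook's tooling: `guess_kws_prefix_from_set` is a hypothetical name, and the keyword list is repeated so that the block runs on its own.

```python
import re

# Same keyword list as above, repeated to keep the sketch self-contained.
AVAILABLE_KEYWORDS = set(
    "buy gift discounts support operator chat message offer lag product "
    "payment process shop cart".split(" ")
)

def guess_kws_prefix_from_set(query, kws=AVAILABLE_KEYWORDS):
    """Return the keywords that start some word of the query (case-insensitive)."""
    _qry = query.lower()
    return {
        kw
        for kw in kws
        # \b anchors the keyword to a word start: "lags" matches, "flagged" does not
        if re.search(r"\b" + re.escape(kw), _qry)
    }

print(guess_kws_prefix_from_set("Why does the site experience these lags?"))
print(guess_kws_prefix_from_set("I have been flagged by an admin... what do I do?"))
```

This still lets some false positives through (for instance, `"cart"` would match "cartoon"), so the semantic side of the search remains the safety net.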


### Known keyword set, using AI

Of course, the next thing we try is to employ AI to nail the keywords for us. This, however, comes at a performance cost. Depending on the use case, an additional delay of one second or more may or may not be acceptable. Let's see what we can do, and defer timing the performance to a later section.

#### HuggingFace zero-shot classifier

There's a new parameter here: the minimum classifier score a keyword must reach to be accepted.

> Install PyTorch first. It is not covered in the `requirements.txt` since it's ... complicated. Please follow [these instructions](https://pytorch.org/get-started/locally/).

In [25]:
from transformers import pipeline

hf_zs_classifier = pipeline("zero-shot-classification")

def guess_kws_hf_zs(query, kws=available_keywords, keyword_threshold=0.5):
    _kws = list(kws)
    result = hf_zs_classifier([query], _kws, multi_label=True)[0]
    return {
        keyword
        for keyword, kw_score in zip(result["labels"], result["scores"])
        if kw_score >= keyword_threshold
    }

No model was supplied, defaulted to facebook/bart-large-mnli and revision c626438 (https://huggingface.co/facebook/bart-large-mnli).
Using a pipeline without specifying a model name and revision in production is not recommended.


Just the keyword extraction in action:

In [26]:
queries = [QUERY1, QUERY2, QUERY3, QUERY4]

for qry in queries:
    print(f"QUERY='{qry}' [keywords={guess_kws_hf_zs(qry)}]")

QUERY='How come I cannot chat?' [keywords={'lag'}]
QUERY='Do you currently have any offers?' [keywords={'discounts', 'operator', 'message', 'offer'}]
QUERY='Why does the site experience these lags?' [keywords={'operator', 'lag'}]
QUERY='I have been flagged by an admin... what do I do?' [keywords={'operator', 'process', 'message', 'chat'}]


Remember you can set the extraction to be more or less generous by playing with the threshold:

In [28]:
print(f"QUERY='{QUERY4}', keywords by threshold for guess_kws_hf_zs:")
for kw_t in [0.3, 0.4, 0.5, 0.7, 0.8, 0.9]:
    print(f"    threshold={kw_t:0.2f} ==> [keywords={guess_kws_hf_zs(QUERY4, keyword_threshold=kw_t)}]")

QUERY='I have been flagged by an admin... what do I do?', keywords by threshold for guess_kws_hf_zs:
    threshold=0.30 ==> [keywords={'offer', 'cart', 'message', 'operator', 'process', 'lag', 'chat'}]
    threshold=0.40 ==> [keywords={'operator', 'process', 'message', 'chat'}]
    threshold=0.50 ==> [keywords={'operator', 'process', 'message', 'chat'}]
    threshold=0.70 ==> [keywords={'operator', 'process', 'message'}]
    threshold=0.80 ==> [keywords={'process', 'message'}]
    threshold=0.90 ==> [keywords=set()]
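
Since each classifier call is relatively slow, one may prefer to run it once and apply several thresholds to the same result. A minimal sketch, assuming the result dictionary carries the `labels` and `scores` lists used in `guess_kws_hf_zs` above; the helper name `keywords_by_threshold` and the sample scores are made up for illustration.

```python
def keywords_by_threshold(result, thresholds):
    """Map each threshold to the set of labels scoring at least that much."""
    pairs = list(zip(result["labels"], result["scores"]))
    return {
        threshold: {label for label, score in pairs if score >= threshold}
        for threshold in thresholds
    }

# Made-up classifier output, shaped like one item returned by the pipeline.
sample_result = {
    "labels": ["message", "process", "operator", "chat"],
    "scores": [0.85, 0.82, 0.75, 0.55],
}

for threshold, kws in keywords_by_threshold(sample_result, [0.5, 0.8, 0.9]).items():
    print(f"    threshold={threshold:0.2f} ==> [keywords={sorted(kws)}]")
```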


#### HuggingFace LLM to get the keywords from a set

This does not seem to lead to any quickly usable result. Abandoned for now (besides, it's slower than calling LLM services).

The attempts are kept here for the record, one per cell:

In [75]:
_ = '''

# FROM: https://huggingface.co/docs/transformers/v4.15.0/en/task_summary#text-generation


hf_tg_llm = pipeline("text-generation")

# just did a few tests with prompts, no luck
KW_EXTRACTION_PROMPT_TEMPLATE = """The relevant keywords extracted from "{query}" are: """
prompt = KW_EXTRACTION_PROMPT_TEMPLATE.format(query=QUERY3)

result = hf_tg_llm(prompt, max_length=50 + len(prompt), do_sample=False)
print(result[0]['generated_text'])
'''

In [76]:
_ = '''
# FROM: https://huggingface.co/docs/transformers/v4.15.0/en/task_summary#text-generation

from transformers import AutoModelForCausalLM, AutoTokenizer


model = AutoModelForCausalLM.from_pretrained("xlnet-base-cased")
tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")

# Padding text helps XLNet with short prompts - proposed by Aman Rusia in https://github.com/rusiaaman/XLNet-gen#methodology
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""

prompt = f"The top search keyword extracted from \"{QUERY4}\" are ..."

inputs = tokenizer(PADDING_TEXT + prompt, add_special_tokens=False, return_tensors="pt")["input_ids"]
prompt_length = len(tokenizer.decode(inputs[0]))
outputs = model.generate(inputs, max_length=250, do_sample=True, top_p=0.95, top_k=60)
generated = prompt + tokenizer.decode(outputs[0])[prompt_length+1:]

print(generated)
'''

#### Using a greater LLM for keywords from a set

What if we swap HuggingFace's LLM for a more "powerful" one, such as OpenAI's `gpt-3.5-turbo`, and have it extract our keywords from a set?

In [92]:
KW_EXTRACTION_PROMPT_TEMPLATE = """
You are to extract keywords from a query string for use in a keyword-based search engine.
Please output them in a comma-separated list.
Be very careful that keywords MUST be given in stemmed form: nouns are singular, verbs are in the infinite, and so on.
Do not exceed a dozen keyword. Keywords must not include whitespaces, i.e. they must be a single word.
Include proper nouns if relevant. Discard stop words, pronouns and generic verbs such as be, do, get and so on.

EXAMPLE QUERY: Does the site currently offer discounts? It featured them on the portal yesterday.
EXAMPLE KEYWORDS: site, offer, discount, feature, portal

QUERY STRING: {query}

KEYWORDS:"""

prompt = KW_EXTRACTION_PROMPT_TEMPLATE.format(query="Is Santa a user of your website? My friend assures me so")

completion_model_name = "gpt-3.5-turbo"

response = openai.ChatCompletion.create(
    model=completion_model_name,
    messages=[{"role": "user", "content": prompt}],
    temperature=0.0,
    max_tokens=20,
)
keywords = {
    tok
    for tok in (
        _tok.strip()
        for _tok in response.choices[0].message.content.lower().split(",")
    )
    if tok
}
print(keywords)

{'friend', 'santa', 'assure', 'user', 'website'}


Note that the quality of the output (the choice of keywords, but also the proper stemming and casing) depends heavily on engineering the right prompt.
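
The prompt above leaves the vocabulary open. To tie the answer back to a known keyword set, one option is to post-filter the model's output. A minimal sketch: `parse_llm_keywords` is a hypothetical helper (not part of the notebook's modules), and it works on the raw completion string so that it can be tried without calling the API.

```python
def parse_llm_keywords(raw_completion, allowed_keywords=None):
    """Split a comma-separated completion into a set of lowercase keywords.

    If allowed_keywords is given, keep only the keywords found in that set.
    """
    keywords = {tok.strip() for tok in raw_completion.lower().split(",")}
    keywords.discard("")  # drop empty pieces caused by stray commas
    if allowed_keywords is not None:
        keywords &= set(allowed_keywords)
    return keywords

raw = " site, Offer, discount, feature, portal"
print(parse_llm_keywords(raw))
print(parse_llm_keywords(raw, allowed_keywords={"offer", "chat", "lag"}))
```

For the filter to work well, the known keywords should follow the same stemming convention the prompt asks for: a set containing `"discounts"` would not match the stemmed `"discount"` coming from the model.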

### Open keyword set, using AI

We will test AI-powered open-set keyword extraction in a couple of ways.

#### HuggingFace NER

In [34]:
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

# FROM: https://huggingface.co/docs/transformers/v4.15.0/en/task_summary#named-entity-recognition

model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: d1723dcd-523d-434a-98c6-562f4fb9ec21)')' thrown while requesting HEAD https://huggingface.co/bert-base-cased/resolve/main/tokenizer_config.json


In [32]:
sequence = "Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO, " \
           "therefore very close to the Manhattan Bridge."

inputs = tokenizer(sequence, return_tensors="pt")

tokens = inputs.tokens()
outputs = model(**inputs).logits
predictions = torch.argmax(outputs, dim=2)

In [33]:
for token, prediction in zip(tokens, predictions[0].numpy()):
    print((token, model.config.id2label[prediction]))

('[CLS]', 'O')
('Hu', 'I-ORG')
('##gging', 'I-ORG')
('Face', 'I-ORG')
('Inc', 'I-ORG')
('.', 'O')
('is', 'O')
('a', 'O')
('company', 'O')
('based', 'O')
('in', 'O')
('New', 'I-LOC')
('York', 'I-LOC')
('City', 'I-LOC')
('.', 'O')
('Its', 'O')
('headquarters', 'O')
('are', 'O')
('in', 'O')
('D', 'I-LOC')
('##UM', 'I-LOC')
('##BO', 'I-LOC')
(',', 'O')
('therefore', 'O')
('very', 'O')
('close', 'O')
('to', 'O')
('the', 'O')
('Manhattan', 'I-LOC')
('Bridge', 'I-LOC')
('.', 'O')
('[SEP]', 'O')
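
To turn per-token predictions like these into candidate keywords, the word pieces (the `##` continuations) have to be glued back together and adjacent tokens with the same label grouped. A minimal sketch under stated assumptions: `group_entities` is a hypothetical helper working on `(token, label)` pairs such as those printed above, and the input below is a shortened copy of that output.

```python
def group_entities(token_label_pairs):
    """Merge word pieces and adjacent same-label tokens into (text, label) entities."""
    entities = []
    words, current = [], None
    for token, label in token_label_pairs:
        if label != current:
            # Label changed: close the entity collected so far, if any.
            if words:
                entities.append((" ".join(words), current))
            words, current = [], label
        if label == "O":
            continue  # "O" marks tokens outside any entity
        if token.startswith("##") and words:
            words[-1] += token[2:]  # glue the word piece onto the previous word
        else:
            words.append(token[2:] if token.startswith("##") else token)
    if words:
        entities.append((" ".join(words), current))
    return entities

pairs = [
    ("[CLS]", "O"), ("Hu", "I-ORG"), ("##gging", "I-ORG"), ("Face", "I-ORG"),
    ("Inc", "I-ORG"), (".", "O"), ("is", "O"), ("in", "O"), ("New", "I-LOC"),
    ("York", "I-LOC"), ("City", "I-LOC"), (".", "O"), ("D", "I-LOC"),
    ("##UM", "I-LOC"), ("##BO", "I-LOC"), (",", "O"), ("Manhattan", "I-LOC"),
    ("Bridge", "I-LOC"), (".", "O"), ("[SEP]", "O"),
]
entities = group_entities(pairs)
# Single-word, lowercase keywords derived from the entities
keywords = {word.lower() for text, _ in entities for word in text.split()}
print(entities)
```

Because this model only emits `I-` tags, two distinct entities of the same type that sit next to each other would be merged into one by this grouping.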
