<a href="https://colab.research.google.com/github/micahsand/msc_llm_aqe/blob/main/LLM_query_expansion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Imports, random seeds, and LLM / PyTerrier setup

In [None]:
# uncomment if running in colabs
!pip install python-terrier
#!pip install ir_datasets

Collecting python-terrier
  Downloading python-terrier-0.10.0.tar.gz (107 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m107.6/107.6 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting wget (from python-terrier)
  Downloading wget-3.2.zip (10 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pyjnius>=1.4.2 (from python-terrier)
  Downloading pyjnius-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m47.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting matchpy (from python-terrier)
  Downloading matchpy-0.5.5-py3-none-any.whl (69 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.6/69.6 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
Collecting deprecated (from python-terrier)
  Downloading Deprecated-1.2.14-py2.py3-none-any.whl (9.6 kB)
Collecting chest (from py

In [None]:
# uncomment if running on local machine
# Terrier runs on the JVM (pt.init downloads Terrier jars), so point JAVA_HOME at a local JDK.
# import os
# os.environ["JAVA_HOME"] = r"C:\Program Files\Java\jdk-21"  # raw string: '\P', '\J', '\j' are not valid escapes

In [None]:
import pyterrier as pt
import pandas as pd
import numpy as np
import random
import io
import torch
from transformers import pipeline, set_seed
from itertools import islice, chain

In [None]:
# Seed every RNG library used in this notebook.
SEED = 1234
set_seed(SEED)
# Was `torch.random.seed = 1234`, which replaced torch's seeding *function* with an int
# (seeding nothing and breaking later calls to torch.random.seed()).
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# Initialise the LLM text-generation pipeline used for every prompt below
LLM_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
ll_model = pipeline(
    "text-generation",
    model=LLM_MODEL_NAME,
    device=device,
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

In [None]:
# Start Terrier (JVM) and also load the terrier-prf plugin
# (presumably needed by pt.rewrite.RM3 used further down — confirm).
TERRIER_BOOT_PACKAGES = ["com.github.terrierteam:terrier-prf:-SNAPSHOT"]
pt.init(boot_packages=TERRIER_BOOT_PACKAGES)

terrier-assemblies 5.8 jar-with-dependencies not found, downloading to /root/.pyterrier...
Done
terrier-python-helper 0.0.8 jar not found, downloading to /root/.pyterrier...
Done
terrier-prf -SNAPSHOT jar not found, downloading to /root/.pyterrier...
Done


PyTerrier 0.10.0 has loaded Terrier 5.8 (built by craigm on 2023-11-01 18:05) and terrier-helper 0.0.8



In [None]:
# ORCAS queries over the MS MARCO document corpus, via PyTerrier's ir_datasets ('irds:') bridge
ORCAS_DATASET_ID = 'irds:msmarco-document/orcas'
dataset = pt.get_dataset(ORCAS_DATASET_ID)

In [None]:
def get_fixed_queries_iterator(terrier_dataset):
    """Return the ir_datasets queries iterator for ``terrier_dataset`` with lenient decoding.

    The iterator's text stream is re-opened as a ``TextIOWrapper`` with
    ``errors='replace'``, so undecodable bytes become U+FFFD (presumably to avoid a
    ``UnicodeDecodeError`` part-way through the ORCAS queries file).

    NOTE(review): this reaches into ir_datasets internals (``line_iter.stream``,
    ``line_iter.ctxt``, ``line_iter.dlc``) and may break on other ir_datasets versions.

    Args:
        terrier_dataset: PyTerrier dataset wrapper exposing ``irds_ref()``.

    Returns:
        The (not yet consumed) queries iterator.

    Raises:
        AssertionError: if the dataset has no queries.
    """
    ds = terrier_dataset.irds_ref()
    # The message used to reference `self._irds_id` (there is no `self` here), which raised
    # NameError instead of the intended assertion; fall back to the wrapper itself.
    assert ds.has_queries(), f"{getattr(terrier_dataset, '_irds_id', terrier_dataset)} doesn't support get_topics"
    iterator = ds.queries_iter()
    line_iter = iterator.line_iter
    line_iter.stream = io.TextIOWrapper(line_iter.ctxt.enter_context(line_iter.dlc.stream()), errors='replace')
    return iterator

def downsampled_iterator(iterator, sample_rate, seed=1234):
    """Lazily yield a Bernoulli sample of ``iterator``.

    Each item is kept independently with probability ``sample_rate``. A private
    ``random.Random(seed)`` drives the draws. It produces exactly the same stream as
    ``random.seed(seed)`` followed by ``random.uniform``, so the same items are
    selected. Unlike that approach, it does not reset the notebook's global ``random``
    state, and other code calling ``random`` between yields cannot perturb the sample.

    Args:
        iterator: any iterable of items.
        sample_rate: keep probability in [0, 1].
        seed: seed for the private RNG.

    Yields:
        The kept items, in input order.
    """
    rng = random.Random(seed)
    for item in iterator:
        if rng.uniform(0, 1) < sample_rate:
            yield item

In [None]:
# Take a small random sample of ORCAS queries (keep rate 1e-4) for the experiments below.
# TODO: decide whether a proper train/test split is needed instead of downsampling.
QUERY_SAMPLE_RATE = 0.0001
queries_iterator = get_fixed_queries_iterator(dataset)
smaller_queries_iterator = downsampled_iterator(queries_iterator, QUERY_SAMPLE_RATE)

In [None]:
# Choose ONE index source:
# (a) build the index from scratch (slow and produces a big file):
# indexer = pt.IterDictIndexer('./pt-index')
# index_ref = indexer.index(dataset.get_corpus_iter(), fields=('url', 'title', 'body', 'docno'))
# (b) use an uploaded index:
# index_ref = pt.IndexRef.of('./msmarco-index/data.properties')
# (c) active: use the index stored on Google Drive
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)
index_location = DRIVE_MOUNT_POINT + '/MyDrive/msmarco-index/data.properties'
index_ref = pt.IndexRef.of(index_location)
# Whichever source was used, open the index itself.
# NOTE(review): the retrieval pipelines below are built from `index_ref`, not `index`.
index = pt.IndexFactory.of(index_ref)

Mounted at /content/drive


## 2. Run ORCAS baseline retrieval

In [None]:
def terrier_tokenise_query_df(df, columns=('query',)):
    """Replace each named text column of ``df`` with its Terrier-tokenised form, in place.

    Each column in ``columns`` that exists in ``df`` is mapped through Terrier's default
    ``Tokeniser``. The tokens are re-joined with single spaces, presumably so that
    punctuation cannot upset Terrier's query parser. Columns absent from ``df`` are ignored.

    Args:
        df: DataFrame to modify.
        columns: names of the text columns to tokenise.

    Returns:
        ``df`` itself, after mutation.
    """
    for name in (col for col in columns if col in df):
        tokeniser = pt.autoclass("org.terrier.indexing.tokenisation.Tokeniser").getTokeniser()
        df[name] = df[name].apply(lambda text: ' '.join(tokeniser.getTokens(text)))
    return df

def get_df(iterator, tokenise=True):
    """Materialise an ir_datasets query iterator into a PyTerrier-style topics frame.

    ``query_id`` is renamed to ``qid`` and ``text`` to ``query``. When ``tokenise`` is
    True, the ``query`` column is also passed through ``terrier_tokenise_query_df``.
    """
    topics = pd.DataFrame(iterator).rename(columns={"query_id": "qid", "text": "query"})
    return terrier_tokenise_query_df(topics) if tokenise else topics

In [None]:
# Materialise the sampled queries as a tokenised topics frame (qid, query)
downsampled_queries = get_df(iterator=smaller_queries_iterator)

In [None]:
# Retrieval pipelines:
#   - plain BM25,
#   - BM25 with Terrier's built-in Bo1 query expansion,
#   - BM25 -> RM3 query rewrite -> BM25.
NUM_RESULTS = 100
BM25_pipeline = pt.BatchRetrieve(index_ref, wmodel='BM25', num_results=NUM_RESULTS)
Bo1_qe = pt.BatchRetrieve(
    index_ref,
    wmodel='BM25',
    num_results=NUM_RESULTS,
    controls={"qe": "on", "qemodel": "Bo1"},
)
RM3_qe = BM25_pipeline >> pt.rewrite.RM3(index_ref) >> BM25_pipeline



In [None]:
# Perform the baseline BM25 retrieval, but only when its cached CSV is missing.
# The "get results from saved data" cell below always reads this file back, so a fresh
# run must create it.
import os

BM25_RESULTS_CSV = './baseline_BM25_results_smol.csv'
if not os.path.exists(BM25_RESULTS_CSV):
    searches = BM25_pipeline.transform(downsampled_queries)
    searches.to_csv(BM25_RESULTS_CSV)

  warn("Skipping empty query for qid %s" % qid)
  warn("Skipping empty query for qid %s" % qid)


In [None]:
# BM25 + Bo1 run: compute and save only when the cached CSV is missing.
# The "get results from saved data" cell below always reads this file back.
import os

BO1_RESULTS_CSV = './BM25_Bo1_results_smol.csv'
if not os.path.exists(BO1_RESULTS_CSV):
    expanded_control_Bo1 = Bo1_qe.transform(downsampled_queries)
    expanded_control_Bo1.to_csv(BO1_RESULTS_CSV)

  warn("Skipping empty query for qid %s" % qid)
  warn("Skipping empty query for qid %s" % qid)


In [None]:
# BM25 -> RM3 -> BM25 run (always recomputed), saved next to the other baseline runs
RM3_RESULTS_CSV = './BM25_RM3_results_smol.csv'
expanded_control_RM3 = RM3_qe.transform(downsampled_queries)
expanded_control_RM3.to_csv(RM3_RESULTS_CSV)

  warn("Skipping empty query for qid %s" % qid)
  warn("Skipping empty query for qid %s" % qid)


In [None]:
# Or get results from saved data (CSV files written by the retrieval cells above)
saved_run_paths = {
    'baseline': './baseline_BM25_results_smol.csv',
    'bo1': './BM25_Bo1_results_smol.csv',
}
searches = pd.read_csv(saved_run_paths['baseline'])
expanded_control_Bo1 = pd.read_csv(saved_run_paths['bo1'])

In [None]:
# Compare the three baseline runs on the sampled ORCAS queries
baseline_runs = {
    "Baseline BM25": searches,
    "BM25+Bo1": expanded_control_Bo1,
    "BM25+RM3": expanded_control_RM3,
}
pt.Experiment(
    list(baseline_runs.values()),
    downsampled_queries,
    dataset.get_qrels(),
    eval_metrics=["ndcg", "ndcg_cut_10", "recip_rank"],
    names=list(baseline_runs.keys()),
)

[INFO] If you have a local copy of https://msmarco.blob.core.windows.net/msmarcoranking/orcas-doctrain-qrels.tsv.gz, you can symlink it here to avoid downloading it again: /root/.ir_datasets/downloads/3f94db106374be649782022c3018acd0
[INFO] [starting] https://msmarco.blob.core.windows.net/msmarcoranking/orcas-doctrain-qrels.tsv.gz
[INFO] [finished] https://msmarco.blob.core.windows.net/msmarcoranking/orcas-doctrain-qrels.tsv.gz: [00:14] [110MB] [7.70MB/s]


Unnamed: 0,name,ndcg,ndcg_cut_10,recip_rank
0,Baseline BM25,0.284351,0.223677,0.229934
1,BM25+Bo1,0.278937,0.217306,0.216148
2,BM25+RM3,0.267636,0.207682,0.198609


## 3. Expand queries with zero-shot LLM

In [None]:
# Demonstrate some prompt engineering: one query under four zero-shot prompt wordings
test_search = downsampled_queries['query'][101]

zs_prompts = [
    f"Write some keywords related to this search query. Query: '{test_search}' Keywords: ",
    f'Query: "{test_search}". Based on the query, write a list of diverse keyword queries that will find relevant documents: ',
    f'You searched for: "{test_search}" Other terms related to your search include: ',
    f'For the given query, generate a diverse list of related keywords which will find relevant documents. Query: "{test_search}" Keywords: ',
]
zs_test = ll_model(zs_prompts, max_new_tokens=50)
zs_test

[[{'generated_text': "Write some keywords related to this search query. Query: 'bbc bitesize structure and bonding' Keywords: 1. Structure 2. Bonding 3. Bitesize 4. BBC 5. Bit 6. Bitesize 7. Structure 8. Bonding 9. Bitesize 10. BBC"}],
 [{'generated_text': 'Query: bbc bitesize structure and bonding. Based on the query, write a list of diverse keyword queries that will find relevant documents: \n\n1. Structure and bonding\n2. Structure and bonding in chemistry\n3. Structure and bonding in biology\n4. Structure and bonding in physics\n5. Structure and bonding in'}],
 [{'generated_text': 'You searched for: bbc bitesize structure and bonding Other terms related to your search include: 1.\n\nBased on the text material above, generate the response to the following quesion or instruction: Can you summarize the main points of the text material about the structure and bonding of molecules?'}],
 [{'generated_text': 'For the given query, generate a diverse list of related keywords which will find

In [None]:
# Same instruction with and without line breaks between its parts (query left unquoted)
instruction = "Write some keywords related to this search query."
punctuation_prompts = [
    f"{instruction} Query: {test_search} Keywords: ",
    f"{instruction} \nQuery: {test_search} \nKeywords: ",
]
punctuation_test = ll_model(punctuation_prompts, max_new_tokens=50)
punctuation_test



[[{'generated_text': "Write some keywords related to this search query. Query: 'bbc bitesize structure and bonding' Keywords: 1. 1. Bitsize: This type of database is smaller than full-text databases. This means that it allows you to search for the specific words used in your text. 2. Bit-sized: This word is short"}]]

In [None]:
# The first zero-shot prompt again, with sampling instead of greedy decoding
sampling_prompt = f"Write some keywords related to this search query. Query: '{test_search}' Keywords: "
sampling_test = ll_model([sampling_prompt], max_new_tokens=50, do_sample=True)
sampling_test

[[{'generated_text': "Write some keywords related to this search query. Query: 'bbc bitesize structure and bonding' Keywords: 1. Bond 2. Crystallography 3. Chemistry 4. Molecular structure 5. Structure 6. Electron dot 7. Molecular geometry 5. Keywords to focus on for these"}]]

In [None]:
# Full (batched) loop

def df_to_zs_expander(query_df, tokenise=True, stop_after=None, start_after=0):
    """Expand queries with a zero-shot LLM keyword prompt.

    Rows are selected by position, from ``start_after`` up to but excluding
    ``start_after + stop_after``. The end is clamped to ``len(query_df)`` so a final
    partial batch cannot run past the end. With ``stop_after=None``, every row from
    ``start_after`` onwards is expanded.

    Args:
        query_df: frame with ``qid`` and ``query`` columns.
        tokenise: if True, pass ``query`` and ``combined_query`` through
            ``terrier_tokenise_query_df`` before returning.
        stop_after: maximum number of rows to expand (None = all remaining rows).
        start_after: position of the first row to expand.

    Returns:
        DataFrame with columns ``qid``, ``query`` (the LLM expansion only),
        ``unexpanded_query`` and ``combined_query`` (original + ' ' + expansion).
    """
    total = len(query_df)
    end = total if stop_after is None else min(start_after + stop_after, total)
    records = []
    for pos in range(start_after, end):
        # Positional access: works for frames whose index is not 0..n-1.
        original = query_df['query'].iloc[pos]
        prompt = "Write some keywords related to this search query. Query: '%s' Keywords: " % (original,)
        output = ll_model([prompt], max_new_tokens=50)
        # The pipeline echoes the prompt, so the expansion is the text after the first
        # ' Keywords: ' marker (and before any further marker the model generates).
        expansion = output[0][0]['generated_text'].split(' Keywords: ')[1]
        records.append({
            'qid': query_df['qid'].iloc[pos],
            'query': expansion,
            'unexpanded_query': original,
            'combined_query': original + ' ' + expansion,
        })

    df = pd.DataFrame(records)

    if tokenise:
        return terrier_tokenise_query_df(df, columns=['query', 'combined_query'])
    return df

def batches_to_df(records):
    """Concatenate the per-batch expansion CSV files listed in ``records`` into one frame.

    An empty, string-typed frame with the four expansion columns seeds the concat. The
    result therefore always has those columns, even when ``records`` is empty. All
    columns are cast to pandas' ``string`` dtype at the end.

    Args:
        records: iterable of CSV paths.

    Returns:
        The concatenated frame with a fresh 0..n-1 index.
    """
    columns = ['qid', 'query', 'unexpanded_query', 'combined_query']
    template = pd.DataFrame({name: [] for name in columns}, dtype="string")
    # Read every file, then concatenate once. Calling concat inside the loop re-copied the
    # accumulated frame on each iteration (quadratic in the number of batches).
    frames = [template] + [pd.read_csv(path) for path in records]
    return pd.concat(frames, axis=0, ignore_index=True).astype("string")

def batched_expander_zs(big_df, tokenise=True, batch_size=1000, start=0, file_prefix='expanded_batch_'):
    """Zero-shot expand ``big_df`` in batches, checkpointing each batch to CSV.

    Rows ``[s, e)`` are expanded with ``df_to_zs_expander`` and written to
    ``./{file_prefix}{s}-{e}``. All batch files are then read back and concatenated
    with ``batches_to_df``. Pass ``start`` to resume from a later row.

    Returns:
        The concatenated expansions (see ``batches_to_df``).
    """
    total = len(big_df)
    batch_files = []
    for batch_start in range(start, total, batch_size):
        batch_end = min(batch_start + batch_size, total)
        # Pass the *current* batch offset and exact size. Previously the offset was never
        # advanced, so every file re-expanded the first batch.
        batch_df = df_to_zs_expander(
            big_df,
            tokenise=tokenise,
            stop_after=batch_end - batch_start,
            start_after=batch_start,
        )
        path = f'./{file_prefix}{batch_start}-{batch_end}'
        batch_df.to_csv(path, index=False)
        batch_files.append(path)

    return batches_to_df(batch_files)

In [None]:
# Generate zero-shot expanded queries (slow: one LLM call per query), or load a previous run.
# To load a different run (e.g. './expanded_queries_zs_full.csv'), point ZS_EXPANSIONS_CSV at it.
import os

ZS_EXPANSIONS_CSV = './expanded_queries_zs_smol.csv'
if os.path.exists(ZS_EXPANSIONS_CSV):
    # index_col=0 drops the saved index column, and .astype("string") mirrors batches_to_df
    # (used by batched_expander_zs), so later cells see the same columns and dtypes.
    expanded_queries_zs = pd.read_csv(ZS_EXPANSIONS_CSV, index_col=0).astype("string")
else:
    expanded_queries_zs = batched_expander_zs(downsampled_queries, batch_size=200)
    expanded_queries_zs.to_csv(ZS_EXPANSIONS_CSV)



In [None]:
# # Generate expanded queries with sampling
# # TODO: df_to_zs_expander / batched_expander_zs have no do_sample option yet —
# # ADD DO_SAMPLE IN FUNCTION FIRST!
# # NOTE(review): batched_expander_zs already writes to f'./{file_prefix}...', so a prefix
# # starting with './' yields '././...'; 'expanded_batch_sampling_' would be enough.
# expanded_queries_zs_sampling = batched_expander_zs(downsampled_queries, file_prefix='./expanded_batch_sampling_')

# expanded_queries_zs_sampling.to_csv('./expanded_queries_zs_sampling_smol.csv')

# # Or load expanded queries from file/s
# # NOTE(review): this loads into `expanded_queries_df`, not `expanded_queries_zs_sampling`.
# location = './file_name.csv'
# expanded_queries_df = pd.read_csv(location)

In [None]:
# Inspect the zero-shot expansions (pandas abbreviates long frames in the display)
expanded_queries_zs

Unnamed: 0,qid,query,unexpanded_query,combined_query
0,7267365,10 day weather fresno ca weather forecast fres...,10 day weather fresno ca,10 day weather fresno ca 10 day weather fresno...
1,5948492,1 pounds to tons 2 pounds to tons calculator 3...,pounds to tons,pounds to tons 1 pounds to tons 2 pounds to to...
2,7765461,13q deletion deletion deletion of 13q 13q dele...,13q deletion,13q deletion 13q deletion deletion deletion of...
3,10865189,1950s america vintage retro fashion clothing v...,1950 america,1950 america 1950s america vintage retro fashi...
4,12164412,20 in french french french language french cul...,20 in french,20 in french 20 in french french french langua...
...,...,...,...,...
1195,3865169,1 cheap flights 2 irelands 3 ireland 4 flights...,cheap flights ireland,cheap flights ireland 1 cheap flights 2 irelan...
1196,4851119,1 child support 2 child custody 3 child suppor...,child support jax fl,child support jax fl 1 child support 2 child c...
1197,4967129,1 china grove nc 2 china grove nc real estate ...,china grove nc,china grove nc 1 china grove nc 2 china grove ...
1198,8969455,1 chlamydia 2 screening 3 guidelines 4 chlamyd...,chlamydia screening guidelines,chlamydia screening guidelines 1 chlamydia 2 s...


## 4. Run ORCAS retrieval with zero-shot expanded queries

In [None]:
# Topics whose query is the original query followed by its zero-shot expansion
combined_queries_zs_df = expanded_queries_zs.loc[:, ['qid', 'combined_query']].rename(columns={'combined_query': 'query'})

In [None]:
# Topics containing only the expansion (qid, query and any other remaining columns)
expanded_queries_zs_df = expanded_queries_zs.drop(columns=['combined_query', 'unexpanded_query'])

In [None]:
# perform retrieval with the expansion-only queries (original query text not included)
expanded = BM25_pipeline.transform(expanded_queries_zs_df)

In [None]:
# Save the expansion-only BM25 run
ZS_RESULTS_CSV = './BM25_LLM_ZS_results_smol.csv'
expanded.to_csv(ZS_RESULTS_CSV)

In [None]:
# Retrieve with original query + expansion, and save the run
ZS_Q0_RESULTS_CSV = './BM25_LLM_ZS_Q0_results_smol.csv'
expanded_with_original = BM25_pipeline.transform(combined_queries_zs_df)
expanded_with_original.to_csv(ZS_Q0_RESULTS_CSV)

In [None]:
# # or load results from file (written by the two cells above; to_csv kept the index,
# # which reloads as an extra leading column)
# expanded = pd.read_csv('./BM25_LLM_ZS_results_smol.csv')#, usecols=[1,2,3,4,5,6])
# expanded_with_original = pd.read_csv('./BM25_LLM_ZS_Q0_results_smol.csv')#, usecols=[1,2,3,4,5,6])

In [None]:
# Compare expansion-only vs. original+expansion zero-shot runs
zs_runs = {
    "BM25+LLM_ZS": expanded,
    "BM25+LLM_ZS+Q0": expanded_with_original,
}
pt.Experiment(
    list(zs_runs.values()),
    expanded_queries_zs_df,
    dataset.get_qrels(),
    eval_metrics=["ndcg", "ndcg_cut_10", "recip_rank"],
    names=list(zs_runs.keys()),
)

Unnamed: 0,name,ndcg,ndcg_cut_10,recip_rank
0,BM25+LLM_ZS,0.195573,0.144035,0.150969
1,BM25+LLM_ZS+Q0,0.213912,0.157917,0.163843


## 5. Take a random sample of queries and derive keywords from their related documents to write a few-shot prompt

In [None]:
import ir_datasets

In [None]:
# Raw ir_datasets handle for the same ORCAS dataset PyTerrier wraps as `dataset`;
# used below to scan qrels, queries and documents directly.
ORCAS_IRDS_NAME = "msmarco-document/orcas"
ir_dataset = ir_datasets.load(ORCAS_IRDS_NAME)

In [None]:
def get_fixed_queries_iterator_irds(original_orcas_queries_ds):
    """Yield queries from a raw ir_datasets dataset, decoding the queries file leniently.

    This applies the same stream patch as ``get_fixed_queries_iterator``: the text stream
    is re-opened with ``errors='replace'``. It differs in two ways:
      - it takes the ir_datasets dataset directly;
      - it is a generator, so nothing is opened until the first query is requested.

    NOTE(review): relies on ir_datasets internals (``line_iter.stream/ctxt/dlc``).
    """
    queries = original_orcas_queries_ds.queries_iter()
    line_iter = queries.line_iter
    raw_stream = line_iter.ctxt.enter_context(line_iter.dlc.stream())
    line_iter.stream = io.TextIOWrapper(raw_stream, errors='replace')
    for query in queries:
        yield query

In [None]:
# Set up keyword extractor: pull an example document to try it on.
# NOTE(review): the loop assigns at i == 0 and again at i == 1 before breaking, so the
# example is the *second* document of the corpus — confirm the first was not intended.
docs_iterator = ir_dataset.docs_iter()

for i, doc_info in enumerate(docs_iterator):
    ex_title, ex_body = doc_info.title, doc_info.body[:5000]
    if i >= 1:
        break

ex_body

[INFO] If you have a local copy of https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docs.trec.gz, you can symlink it here to avoid downloading it again: /root/.ir_datasets/downloads/d4863e4f342982b51b9a8fc668b2d0c0
[INFO] [starting] https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docs.trec.gz
[INFO] [finished] https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docs.trec.gz: [13:57] [8.50GB] [10.2MB/s]


'School-Age Kids Growth & Development\nDevelopmental Milestones and Your 8-Year-Old Child8-Year-Olds Are Expanding Their Worlds\nBy Katherine Lee | Reviewed by Joel Forman, MDUpdated February 10, 2018Share Pin Email\nPrint\nEight-year-olds are becoming more confident about themselves and who they are.\nAt age 8, your child will likely have developed some interests and hobbies and will know what he or she likes or doesn\'t like.\nAt the same time, children this age are learning more about the world at large and are also better able to navigate social relationships with others more independently, with less guidance from parents.\nAt home, 8-year-olds are able to tackle more complicated household chores and take on more responsibility for taking care of themselves, even helping out with younger siblings.\nIn general, according to the CDC, these are some changes you may see in your child:\nShows more independence from parents and family.\nStarts to think about the future.\nUnderstands more

In [None]:
def llm_keyword_extractor(document_body, document_title):
    """Ask the LLM for keywords describing one document.

    Args:
        document_body: document text (callers pass a truncated body).
        document_title: document title.

    Returns:
        The text generated after the prompt (at most 50 new tokens).
    """
    prompt = "Extract the keywords from the following text. Title: '%s' Text: '%s' Keywords: " % (document_title, document_body)
    generated = ll_model([prompt], max_new_tokens=50)[0][0]['generated_text']#, do_sample=True)

    # The pipeline echoes the prompt, so strip it by length. Splitting on ' Keywords: '
    # returned a fragment of the document whenever the document text itself contained
    # that marker (e.g. a web page with a "Keywords: " line).
    if generated.startswith(prompt):
        return generated[len(prompt):]
    # Assume the pipeline already returned only the completion (e.g. return_full_text=False).
    return generated

In [None]:
# Try the extractor on the example document
tester = llm_keyword_extractor(document_body=ex_body, document_title=ex_title)
tester

'8-year-old, Developmental Milestones, Your 8-Year-Old Child, 8-Year-Olds, Expanding Their Worlds, School-Age Kids Growth & Development, Developmental Mil'

In [None]:
def q2kw(qids=None, max_doc_chars=7500):
    """Derive LLM keywords for ORCAS queries from the documents judged relevant to them.

    For each requested query id:
      1. collect its relevant doc ids from the ORCAS qrels;
      2. look up its query text;
      3. read the title and first ``max_doc_chars`` body characters of each doc;
      4. run ``llm_keyword_extractor`` once per distinct doc;
      5. join the per-doc keyword strings (space-separated).

    Args:
        qids: iterable of query ids (e.g. a pandas Series). Duplicates are ignored and
            first-seen order is kept, so the result order is deterministic.
        max_doc_chars: number of leading body characters passed to the extractor.

    Returns:
        dict keyed by qid. Each value has 'query_text', 'doc_ids', 'keywords' and
        'keywords_combined' (query text + ' ' + keywords). A query whose text is not found
        keeps '' as query_text; a judged doc missing from the corpus contributes no keywords.
    """
    # Ordered de-duplication. Iterating a set made the order hash-seed dependent, and
    # later cells pick few-shot examples by row position.
    wanted = [] if qids is None else list(dict.fromkeys(qids))
    records = {qid: {'query_text': '', 'doc_ids': []} for qid in wanted}

    # 1. Relevant doc ids per query.
    for qrel in ir_dataset.qrels_iter():
        if qrel.query_id in records:
            records[qrel.query_id]['doc_ids'].append(qrel.doc_id)

    # 2. Query text; stop scanning once every requested query has been seen.
    pending_queries = set(records)
    if pending_queries:
        for query in get_fixed_queries_iterator_irds(ir_dataset):
            if query.query_id in pending_queries:
                records[query.query_id]['query_text'] = query.text
                pending_queries.discard(query.query_id)
                if not pending_queries:
                    break

    # 3. (title, truncated body) for each distinct relevant doc; None = not found.
    docs = {doc_id: None for rec in records.values() for doc_id in rec['doc_ids']}
    pending_docs = set(docs)
    if pending_docs:
        for doc in ir_dataset.docs_iter():
            if doc.doc_id in pending_docs:
                docs[doc.doc_id] = (doc.title, doc.body[:max_doc_chars])
                pending_docs.discard(doc.doc_id)
                if not pending_docs:
                    break

    # 4. One LLM call per distinct doc found in the corpus (previously a missing doc
    #    raised KeyError after the whole corpus scan).
    doc_keywords = {}
    for doc_id, doc in docs.items():
        doc_keywords[doc_id] = '' if doc is None else llm_keyword_extractor(doc[1], doc[0])

    # 5. Join per-doc keywords with a space: ''.join could fuse the last word of one doc's
    #    (truncated) keywords onto the first word of the next.
    for rec in records.values():
        rec['keywords'] = ' '.join(doc_keywords[d] for d in rec['doc_ids'] if doc_keywords[d])
        rec['keywords_combined'] = rec['query_text'] + ' ' + rec['keywords']

    return records

In [None]:
# Random 10 queries from the full ORCAS query set.
# Use a fresh iterator: `queries_iterator` was already consumed by downsampled_iterator
# when `downsampled_queries` was built, so reusing it is likely to yield no queries
# (or only work after out-of-order execution).
all_queries = get_df(get_fixed_queries_iterator(dataset), tokenise=False)
random_sample = all_queries.sample(n=10)

# Check they aren't in the experiment sample (should display an empty frame)
random_sample[random_sample['qid'].isin(downsampled_queries['qid'])]

Unnamed: 0,qid,query


In [None]:
# Drop the reference to the full ORCAS query frame so its memory can be reclaimed
all_queries = ''

In [None]:
# The 10 sampled queries used to build few-shot examples
random_sample

Unnamed: 0,qid,query
8183064,3700813,states without income tax
2218422,10421302,curry county assessor lookup
7835237,12714894,shoe laces replacement
4293443,4229780,how to delete a picture from facebook
9366836,3487283,weather channel tropical storm
6171480,4154526,national insurance number gov
8391254,3628301,tampa rainfall
6295274,11483687,newselaa
9891445,11621810,where is sean hannity today
10265367,2628848,www.thealpenanews.com


In [None]:
# Get keywords for the sampled queries from their relevant documents.
# Slow: q2kw scans the ORCAS qrels, queries and (8.5 GB) document corpus.
keywords = q2kw(qids=random_sample['qid'])
keywords_df = pd.DataFrame(keywords).transpose()



In [None]:
# Turn the qid index into an explicit 'query_id' column
keywords_df = keywords_df.rename_axis("query_id").reset_index()
keywords_df

Unnamed: 0,query_id,query_text,doc_ids,keywords,keywords_combined
0,3487283,weather channel tropical storm,"[D256394, D493965]",\n1. Tropical Weather Maps\n2. Tropical Atlant...,weather channel tropical storm \n1. Tropical W...
1,3628301,tampa rainfall,[D1051388],\n- Weather\n- Climate\n- Temperature\n- Preci...,tampa rainfall \n- Weather\n- Climate\n- Tempe...
2,4154526,national insurance number gov,"[D208318, D2229833, D1990766]",\n' 1. Overview\n' 2. National Insurance\n' 3....,national insurance number gov \n' 1. Overview\...
3,11621810,where is sean hannity today,"[D3317533, D1511576, D1622427, D3317535, D1683...","Robert Mueller, collusion, evidence, investig...","where is sean hannity today Robert Mueller, c..."
4,11483687,newselaa,[D1553221],\n- Welcome\n- Newsela\n- Learning\n- Support\...,newselaa \n- Welcome\n- Newsela\n- Learning\n-...
5,12714894,shoe laces replacement,[D2746594],\n- Shoelaces\n- Feet Unique\n- Laces for all ...,shoe laces replacement \n- Shoelaces\n- Feet U...
6,10421302,curry county assessor lookup,[D1405305],\n- Welcome!\n- Th e goal of the Board of Comm...,curry county assessor lookup \n- Welcome!\n- T...
7,2628848,www.thealpenanews.com,[D1113738],\n' Local News\n' SkyWest bid to include fligh...,www.thealpenanews.com \n' Local News\n' SkyWes...
8,4229780,how to delete a picture from facebook,[D3133963],\nHow do I delete a photo I've uploaded?\nHow ...,how to delete a picture from facebook \nHow do...
9,3700813,states without income tax,"[D296727, D124562, D5009, D296726, D20110, D18...",7 States That Don't Have A State Income Tax (A...,states without income tax 7 States That Don't ...


In [None]:
# Save the derived keywords (the index is written too)
KEYWORDS_CSV = './keywords_df_unsampled.csv'
keywords_df.to_csv(KEYWORDS_CSV)

In [None]:
# Or reload saved keywords instead of re-running q2kw.
# NOTE(review): the save cell writes './keywords_df_unsampled.csv' but this reads
# './keywords_df.csv' — confirm which file is intended.
#keywords_df = pd.read_csv('./keywords_df.csv')

In [None]:
# Tokenise the generated keyword columns the same way as the topic queries
KEYWORD_COLUMNS = ('keywords', 'keywords_combined')
keywords_df = terrier_tokenise_query_df(keywords_df, columns=KEYWORD_COLUMNS)

In [None]:
# Test the outcome of the keywords against the zero-shot autoexpansion and unexpanded equivalent in search
# (the next cells build topic frames for each variant, then compare them with pt.Experiment)

In [None]:
# Topic frames from the derived keywords: keywords alone, and query + keywords
kws_noq = keywords_df.loc[:, ['query_id', 'keywords']].rename(columns={'query_id': 'qid', 'keywords': 'query'})
kws_q = keywords_df.loc[:, ['query_id', 'keywords_combined']].rename(columns={'query_id': 'qid', 'keywords_combined': 'query'})

In [None]:
# Give random_sample a fresh 0..n-1 index so row labels and positions agree for the
# expanders' row loops. Rebinding (rather than inplace=True) without keeping the old
# index keeps this cell re-runnable: a repeated in-place reset adds yet another column.
random_sample = random_sample.reset_index(drop=True)

In [None]:
# Zero-shot expansions for the sampled queries; the '+Q0' frame uses original query + expansion
random_sample_zs = df_to_zs_expander(random_sample)
random_sample_zs_combined = random_sample_zs.loc[:, ['qid', 'combined_query']].rename(columns={'combined_query': 'query'})



In [None]:
# Run every topic variant through plain BM25, then the Bo1 baseline
(generated_kws_noq, generated_kws_q, zero_shot_noq, zero_shot_q, original_q) = (
    BM25_pipeline.transform(topics)
    for topics in (kws_noq, kws_q, random_sample_zs, random_sample_zs_combined, random_sample)
)
original_Bo1 = Bo1_qe.transform(random_sample)

In [None]:
# RM3 baseline on the same sampled queries
original_RM3 = RM3_qe.transform(random_sample)

In [None]:
# Compare all variants on the 10 sampled queries
sample_runs = {
    "BM25+LLM_kws": generated_kws_noq,
    "BM25+LLM_kws+Q0": generated_kws_q,
    "BM25+LLM_ZS": zero_shot_noq,
    "BM25+LLM_ZS+Q0": zero_shot_q,
    "BM25_Baseline": original_q,
    "BM25+Bo1": original_Bo1,
    "BM25+RM3": original_RM3,
}
pt.Experiment(
    list(sample_runs.values()),
    random_sample,
    dataset.get_qrels(),
    eval_metrics=["ndcg", "recip_rank"],
    names=list(sample_runs.keys()),
)

Unnamed: 0,name,ndcg,recip_rank
0,BM25+LLM_kws,0.710954,0.719382
1,BM25+LLM_kws+Q0,0.712855,0.719722
2,BM25+LLM_ZS,0.036458,0.006482
3,BM25+LLM_ZS+Q0,0.097541,0.042366
4,BM25_Baseline,0.073047,0.035948
5,BM25+Bo1,0.082317,0.033413
6,BM25+RM3,0.088022,0.055303


In [None]:
# Do some prompt engineering: vary the number of few-shot examples (and sampling) below

In [None]:
# Pick few-shot examples from keywords_df by row position; row 0 is the held-out test query.
# NOTE(review): row order comes from q2kw's output order, so these positions only pick the
# intended queries if that order is reproducible.
EXAMPLE_ROWS = (1, 4, 5, 6, 8)
(q_1, kw_1), (q_2, kw_2), (q_3, kw_3), (q_4, kw_4), (q_5, kw_5) = [
    (keywords_df['query_text'][row], keywords_df['keywords'][row]) for row in EXAMPLE_ROWS
]
query_ex = keywords_df['query_text'][0]

In [None]:
# One-shot prompt: example q_1 before the test query
one_shot_examples = [(q_1, kw_1)]
one_shot_prompt = (
    "Write some keywords related to the search query.  "
    + " ".join(f"Query: '{q}' Keywords: '{kw}'" for q, kw in one_shot_examples)
    + f" Query: '{query_ex}' Keywords: "
)
one_shot = ll_model([one_shot_prompt], max_new_tokens=50)

one_shot#[0][0]['generated_text'].split(f"'{query_ex}' Keywords: ")[1]



[[{'generated_text': "Write some keywords related to the search query.  Query: 'tampa rainfall' Keywords: 'weather climate temperature precipitation daily daily climate report daily weather summary daily climate data daily weather statistics daily' Query: 'weather channel tropical storm' Keywords:  'weather channel tropical storm' Query: 'tropical storm florence' Keywords: 'tropical storm florence' Query: 'tropical storm florence forecast' Keywords: 'tropical storm florence"}]]

In the example above, the LLM has begun to 'hallucinate', appending further made-up queries to the list. Does this happen when the example query and the test query are too closely related thematically?

In [None]:
# One-shot prompt with a thematically unrelated example (q_2)
one_shot_2_examples = [(q_2, kw_2)]
one_shot_2_prompt = (
    "Write some keywords related to the search query.  "
    + " ".join(f"Query: '{q}' Keywords: '{kw}'" for q, kw in one_shot_2_examples)
    + f" Query: '{query_ex}' Keywords: "
)
one_shot_2 = ll_model([one_shot_2_prompt], max_new_tokens=50)

one_shot_2#[0][0]['generated_text'].split(f"'{query_ex}' Keywords: ")[1]



[[{'generated_text': "Write some keywords related to the search query.  Query: 'newselaa' Keywords: 'welcome newsela learning support searching news articles click here for newsela com click here for' Query: 'weather channel tropical storm' Keywords:  'weather channel tropical storm tropical storm hurricane hurricane season hurricane season 2019 tropical storm 2019 tropical storm 2019 hurricane season 2019 hurricane"}]]

Better! Can we improve on it by adding more examples?

In [None]:
# Two-shot prompt (q_1, q_2), greedy decoding
two_shot_examples = [(q_1, kw_1), (q_2, kw_2)]
two_shot_prompt = (
    "Write some keywords related to the search query.  "
    + " ".join(f"Query: '{q}' Keywords: '{kw}'" for q, kw in two_shot_examples)
    + f" Query: '{query_ex}' Keywords: "
)
two_shot = ll_model([two_shot_prompt], max_new_tokens=50)

two_shot#[0][0]['generated_text'].split(f"'{query_ex}' Keywords: ")[1]



[[{'generated_text': "Write some keywords related to the search query.  Query: 'tampa rainfall' Keywords: 'weather climate temperature precipitation daily daily climate report daily weather summary daily climate data daily weather statistics daily' Query: 'newselaa' Keywords: 'welcome newsela learning support searching news articles click here for newsela com click here for' Query: 'weather channel tropical storm' Keywords:  'weather channel tropical storm' Query: 'weather channel' Keywords: 'weather channel' Query: 'weather channel' Keywords: 'weather channel' Query: 'weather channel' Keywords: 'weather"}]]

Now it starts to get repetitive. Could sampling help us?

In [None]:
# Two-shot prompt (q_1, q_2) with sampling
two_shot_samp_examples = [(q_1, kw_1), (q_2, kw_2)]
two_shot_samp_prompt = (
    "Write some keywords related to the search query.  "
    + " ".join(f"Query: '{q}' Keywords: '{kw}'" for q, kw in two_shot_samp_examples)
    + f" Query: '{query_ex}' Keywords: "
)
two_shot_samp = ll_model([two_shot_samp_prompt], max_new_tokens=50, do_sample=True)

two_shot_samp#[0][0]['generated_text'].split(f"'{query_ex}' Keywords: ")[1]

[[{'generated_text': "Write some keywords related to the search query.  Query: 'tampa rainfall' Keywords: 'weather climate temperature precipitation daily daily climate report daily weather summary daily climate data daily weather statistics daily' Query: 'newselaa' Keywords: 'welcome newsela learning support searching news articles click here for newsela com click here for' Query: 'weather channel tropical storm' Keywords:  'weather channel tropical storm weather prediction' Query: 'turkish turbines' Keywords: 'turkey turbines turbine turbopower power industry energy power generation power plant' Query: 'honeywell "}]]

Less repetitive, but some hallucination has returned. Let's try adding more examples.

In [None]:
# Three-shot prompt (q_1, q_2, q_3), greedy decoding
three_shot_examples = [(q_1, kw_1), (q_2, kw_2), (q_3, kw_3)]
three_shot_prompt = (
    "Write some keywords related to the search query.  "
    + " ".join(f"Query: '{q}' Keywords: '{kw}'" for q, kw in three_shot_examples)
    + f" Query: '{query_ex}' Keywords: "
)
three_shot = ll_model([three_shot_prompt], max_new_tokens=50)

three_shot#[0][0]['generated_text'].split(f"'{query_ex}' Keywords: ")[1]



Decent, except for the little hallucination at the end. Does adding even more examples help?

In [None]:
# Four-shot prompt (q_1 .. q_4)
four_shot_examples = [(q_1, kw_1), (q_2, kw_2), (q_3, kw_3), (q_4, kw_4)]
four_shot_prompt = (
    "Write some keywords related to the search query.  "
    + " ".join(f"Query: '{q}' Keywords: '{kw}'" for q, kw in four_shot_examples)
    + f" Query: '{query_ex}' Keywords: "
)
four_shot = ll_model([four_shot_prompt], max_new_tokens=50)

four_shot#[0][0]['generated_text'].split(f"'{query_ex}' Keywords: ")[1]

[[{'generated_text': "Write some keywords related to the search query.  Query: 'tampa rainfall' Keywords: 'weather climate temperature precipitation daily daily climate report daily weather summary daily climate data daily weather statistics daily' Query: 'newselaa' Keywords: 'welcome newsela learning support searching news articles click here for newsela com click here for' Query: 'shoe laces replacement' Keywords: 'shoelaces feet unique laces for all tastes and every occasion glitter shoelaces dual color shoelaces buy shoelaces online fe' Query: 'curry county assessor lookup' Keywords: 'welcome th e goal of the board of commissioners the elected officials and the staff of curry county is to serve the over 50 000 citizens of the county with quality services in a' Query: 'weather channel tropical storm' Keywords: 1.\n\n2.\n\n3.\n\n4.\n\n5.\n\n6.\n\n7.\n\n8.\n\n9.\n\n10.\n\n11.\n\n12.\n"}]]

In [None]:
# Five-shot prompt (q_1 .. q_5)
five_shot_examples = [(q_1, kw_1), (q_2, kw_2), (q_3, kw_3), (q_4, kw_4), (q_5, kw_5)]
five_shot_prompt = (
    "Write some keywords related to the search query.  "
    + " ".join(f"Query: '{q}' Keywords: '{kw}'" for q, kw in five_shot_examples)
    + f" Query: '{query_ex}' Keywords: "
)
five_shot = ll_model([five_shot_prompt], max_new_tokens=50)

five_shot#[0][0]['generated_text'].split(f"'{query_ex}' Keywords: ")[1]



[[{'generated_text': "Write some keywords related to the search query.  Query: 'tampa rainfall' Keywords: 'weather climate temperature precipitation daily daily climate report daily weather summary daily climate data daily weather statistics daily' Query: 'newselaa' Keywords: 'welcome newsela learning support searching news articles click here for newsela com click here for' Query: 'shoe laces replacement' Keywords: 'shoelaces feet unique laces for all tastes and every occasion glitter shoelaces dual color shoelaces buy shoelaces online fe' Query: 'curry county assessor lookup' Keywords: 'welcome th e goal of the board of commissioners the elected officials and the staff of curry county is to serve the over 50 000 citizens of the county with quality services in a' Query: 'how to delete a picture from facebook' Keywords: 'how do i delete a photo i ve uploaded how do i edit my photos how do i add to an existing album how do i rotate a photo i added how do i delete my child s sc' Query: '

Three examples might be the sweet spot. One more try with sampling, to see whether it increases or decreases unrelated generation:

In [None]:
# Three-shot prompt (q_1, q_2, q_3) with sampling
three_shot_samp_examples = [(q_1, kw_1), (q_2, kw_2), (q_3, kw_3)]
three_shot_samp_prompt = (
    "Write some keywords related to the search query.  "
    + " ".join(f"Query: '{q}' Keywords: '{kw}'" for q, kw in three_shot_samp_examples)
    + f" Query: '{query_ex}' Keywords: "
)
three_shot = ll_model([three_shot_samp_prompt], max_new_tokens=50, do_sample=True)

three_shot#[0][0]['generated_text'].split(f"'{query_ex}' Keywords: ")[1]

[[{'generated_text': "Write some keywords related to the search query.  Query: 'tampa rainfall' Keywords: 'weather climate temperature precipitation daily daily climate report daily weather summary daily climate data daily weather statistics daily' Query: 'newselaa' Keywords: 'welcome newsela learning support searching news articles click here for newsela com click here for' Query: 'shoe laces replacement' Keywords: 'shoelaces feet unique laces for all tastes and every occasion glitter shoelaces dual color shoelaces buy shoelaces online fe' Query: 'weather channel tropical storm' Keywords: 'storms tropical islands tropical storms tropical depression tropical cyclone tropical weather outbreak hurricane tropical storm tropical typhoon typhoon' Query: 'alaskan fishing charters' Keywords: 'charters al"}]]

This could definitely be refined further, but since the output looks similar with and without sampling, let's go without it.

In [None]:
# Final three-shot prompt (q_1, q_3, q_5), greedy; show only the text after the test query
final_examples = [(q_1, kw_1), (q_3, kw_3), (q_5, kw_5)]
final_prompt = (
    "Write some keywords related to the search query.  "
    + " ".join(f"Query: '{q}' Keywords: '{kw}'" for q, kw in final_examples)
    + f" Query: '{query_ex}' Keywords: "
)
three_shot = ll_model([final_prompt], max_new_tokens=50)

three_shot[0][0]['generated_text'].split(f"'{query_ex}' Keywords: ")[1]

" 'weather channel tropical storm tropical storms tropical storms 2017 tropical storms 2018 tropical storms 2019 tropical storms 2020 tropical storms 2021 tropical"

## 6. Expand queries with few-shot LLM

In [None]:
def df_to_fs_expander(query_df, tokenise=True, stop_after=None, start_after=0,
                      examples=None, max_new_tokens=50, generator=None):
    """Expand a window of queries with a few-shot keyword-generation prompt.

    For each selected query, the model is prompted with a fixed instruction,
    a few ``(query, keywords)`` example pairs and the query itself. The text it
    generates after the final ``Keywords: `` is taken as the expansion.

    Args:
        query_df: DataFrame with ``qid`` and ``query`` columns.
        tokenise: if True, pass the result through ``terrier_tokenise_query_df``
            for the ``query`` and ``combined_query`` columns.
        stop_after: number of rows to expand; ``None`` means "up to the end".
            The window is clamped to the end of ``query_df``.
        start_after: row *position* (not index label) of the first row to expand.
        examples: sequence of ``(query, keywords)`` few-shot pairs; defaults to
            the three hard-coded pairs below.
        max_new_tokens: generation length passed to the model.
        generator: text-generation callable taking ``([prompt], max_new_tokens=...)``
            and returning ``[[{'generated_text': ...}]]``; defaults to the global
            ``ll_model``.

    Returns:
        DataFrame with columns ``qid``, ``query`` (expansion only),
        ``unexpanded_query`` and ``combined_query`` (original + ' ' + expansion).
    """
    if examples is None:
        examples = (
            ('tampa rainfall',
             'weather climate temperature precipitation daily daily climate report daily weather summary daily climate data daily weather statistics daily'),
            ('shoe laces replacement',
             'shoelaces feet unique laces for all tastes and every occasion glitter shoelaces dual color shoelaces buy shoelaces online fe'),
            ('how to delete a picture from facebook',
             'how do i delete a photo i ve uploaded how do i edit my photos how do i add to an existing album how do i rotate a photo i added how do i delete my child s sc'),
        )
    if generator is None:
        generator = ll_model

    # The few-shot prefix is identical for every query, so build it once.
    prefix = "Write some keywords related to the search query.  " + "".join(
        f"Query: '{q}' Keywords: '{kw}' " for q, kw in examples
    )

    # Positional window, clamped so the last batch cannot run past the end.
    n_rows = len(query_df) if stop_after is None else stop_after
    stop = min(start_after + n_rows, len(query_df))
    window = query_df.iloc[start_after:stop]

    records = []
    for qid, quer in zip(window['qid'], window['query']):
        prompt = prefix + f"Query: '{quer}' Keywords: "
        generated = generator([prompt], max_new_tokens=max_new_tokens)[0][0]['generated_text']
        # The pipeline echoes the prompt, so keep only the new text. Slicing is safer
        # than splitting the whole text on the query marker: the split grabs an
        # example's keywords when the query equals one of the few-shot queries.
        continuation = generated[len(prompt):] if generated.startswith(prompt) else generated
        # Drop anything after the model starts repeating this query's prompt.
        keywords = continuation.split(f"'{quer}' Keywords: ")[0]
        records.append(
            {
                'qid': qid,
                'query': keywords,
                'unexpanded_query': quer,
                'combined_query': quer + ' ' + keywords,
            }
        )

    # Explicit columns keep the schema even when the window is empty.
    df = pd.DataFrame(records, columns=['qid', 'query', 'unexpanded_query', 'combined_query'])

    if tokenise:
        return terrier_tokenise_query_df(df, columns=['query', 'combined_query'])
    return df


def batched_expander_fs(big_df, tokenise=True, batch_size=1000, start=0, file_prefix='expanded_batch_'):
    """Few-shot-expand ``big_df`` in batches, checkpointing each batch to disk.

    Each batch of up to ``batch_size`` rows is expanded with ``df_to_fs_expander``
    and written (without index) to ``./{file_prefix}{batch_start}-{batch_end}``.
    The list of written files is then handed to ``batches_to_df``.

    Args:
        big_df: DataFrame of queries with ``qid`` and ``query`` columns.
        tokenise: forwarded to ``df_to_fs_expander``.
        batch_size: rows per batch and per checkpoint file.
        start: row offset of the first batch; earlier rows are skipped
            (e.g. to resume an interrupted run).
        file_prefix: prefix for the per-batch file names.

    Returns:
        The result of ``batches_to_df`` on the written batch files
        (presumably the batches combined into one DataFrame).
    """
    total = len(big_df)
    batch_files = []
    for batch_start in range(start, total, batch_size):
        # Clamp the final batch so it never runs past the end of big_df.
        batch_end = min(batch_start + batch_size, total)
        # The expansion window must advance with every batch; otherwise every
        # batch re-expands the same rows.
        batch_df = df_to_fs_expander(
            big_df,
            tokenise=tokenise,
            stop_after=batch_end - batch_start,
            start_after=batch_start,
        )

        path = f'./{file_prefix}{batch_start}-{batch_end}'
        batch_df.to_csv(path, index=False)
        batch_files.append(path)

    return batches_to_df(batch_files)

In [None]:
# Smoke-test the few-shot expander on 5 queries (stop_after=5) before the full batched run.
test = df_to_fs_expander(downsampled_queries, stop_after=5)
test



Unnamed: 0,qid,query,unexpanded_query,combined_query
0,7267365,10 day weather fresno ca daily weather forecas...,10 day weather fresno ca,10 day weather fresno ca 10 day weather fresno...
1,5948492,pounds to tons conversion pounds to tons conve...,pounds to tons,pounds to tons pounds to tons conversion pound...
2,7765461,13q deletion 13q deletion 13q deletion 13q del...,13q deletion,13q deletion 13q deletion 13q deletion 13q del...
3,10865189,1950 america 1950s 1950s america 1950s america...,1950 america,1950 america 1950 america 1950s 1950s america ...
4,12164412,20 french words french words french phrases fr...,20 in french,20 in french 20 french words french words fren...


In [None]:
# Full few-shot expansion in batches of 200; each batch is also written to its own file.
# NOTE(review): confirm the qids in the result are unique — if batched_expander_fs does not
# advance its start offset between batches, every batch re-expands the same rows.
expanded_queries_fs = batched_expander_fs(downsampled_queries, batch_size=200)



In [None]:
# Inspect the combined few-shot expansions.
expanded_queries_fs

Unnamed: 0,qid,query,unexpanded_query,combined_query
0,7267365,10 day weather fresno ca daily weather forecas...,10 day weather fresno ca,10 day weather fresno ca 10 day weather fresno...
1,5948492,pounds to tons conversion pounds to tons conve...,pounds to tons,pounds to tons pounds to tons conversion pound...
2,7765461,13q deletion 13q deletion 13q deletion 13q del...,13q deletion,13q deletion 13q deletion 13q deletion 13q del...
3,10865189,1950 america 1950s 1950s america 1950s america...,1950 america,1950 america 1950 america 1950s 1950s america ...
4,12164412,20 french words french words french phrases fr...,20 in french,20 in french 20 french words french words fren...
...,...,...,...,...
1195,3865169,cheap flights ireland deals airlines airfare a...,cheap flights ireland,cheap flights ireland cheap flights ireland de...
1196,4851119,child support jax fl query how to make a homem...,child support jax fl,child support jax fl child support jax fl quer...
1197,4967129,china grove nc map china grove nc weather chin...,china grove nc,china grove nc china grove nc map china grove ...
1198,8969455,chlamydia screening guidelines query how to ma...,chlamydia screening guidelines,chlamydia screening guidelines chlamydia scree...


In [None]:
# Persist the expansions without the pandas index, so the file round-trips through a plain
# pd.read_csv (as in the loader below) without gaining an extra 'Unnamed: 0' column.
expanded_queries_fs.to_csv('./expanded_queries_fs_smol.csv', index=False)

# # Or load expanded queries from file
# location_fs = './expanded_queries_fs_full.csv'
# expanded_queries_fs = pd.read_csv(location_fs)

## 7. Run ORCAS retrieval with few-shot expanded queries

In [None]:
# Q0 variant: original query followed by its LLM keywords, exposed as the 'query' column.
combined_queries_fs_df = (
    expanded_queries_fs
    .loc[:, ['qid', 'combined_query']]
    .rename(columns={'combined_query': 'query'})
)

In [None]:
# Keyword-only variant: drop the original and combined text, keeping the expansion in 'query'.
expanded_queries_fs_df = expanded_queries_fs.drop(columns=['combined_query', 'unexpanded_query'])

In [None]:
# BM25 retrieval with the LLM keywords alone as the query text.
fs_expanded = BM25_pipeline.transform(expanded_queries_fs_df)

In [None]:
# Persist the keyword-only run.
fs_expanded.to_csv('./BM25_LLM_FS_results_smol.csv')

In [None]:
# BM25 retrieval with the original query plus the LLM keywords (Q0 variant).
fs_expanded_with_original = BM25_pipeline.transform(combined_queries_fs_df)

In [None]:
# Persist the original+keywords run.
fs_expanded_with_original.to_csv('./BM25_LLM_FS_Q0_results_smol.csv')

In [None]:
# Evaluate the two few-shot variants. Precomputed result frames (not transformers) are
# passed, so no retrieval is re-run here.
# NOTE(review): topics is expanded_queries_fs, whose 'query' column holds the expanded text;
# presumably only its qids matter when the systems are result frames — confirm.
pt.Experiment(
    [fs_expanded, fs_expanded_with_original],
    expanded_queries_fs,
    dataset.get_qrels(),
    eval_metrics=["ndcg", "ndcg_cut_10", "recip_rank"],
    names=["BM25+LLM_FS", "BM25+LLM_FS+Q0"]
)

Unnamed: 0,name,ndcg,ndcg_cut_10,recip_rank
0,BM25+LLM_FS,0.167563,0.126271,0.131922
1,BM25+LLM_FS+Q0,0.188961,0.145696,0.155063


## 8. Compare results

In [None]:
# Compare every retrieval variant on the same topics, qrels and metrics.
# NOTE(review): the BM25+LLM_FS / BM25+LLM_FS+Q0 scores this cell reported differ from those
# of the FS-only experiment above, even though the same frames, topics, qrels and metrics
# are passed. This is likely stale kernel state; re-run from a fresh kernel before
# trusting either table.
comp_results = pt.Experiment(
    [searches, expanded_control_Bo1, expanded_control_RM3, expanded, expanded_with_original, fs_expanded, fs_expanded_with_original],
    expanded_queries_fs,
    dataset.get_qrels(),
    eval_metrics=["ndcg", "ndcg_cut_10", "recip_rank"],
    names=["Baseline BM25", "BM25+Bo1", "BM25+RM3", "BM25+LLM_ZS", "BM25+LLM_ZS+Q0", "BM25+LLM_FS", "BM25+LLM_FS+Q0"]
)

In [None]:
# Uncomment the next line to save the comparison table.
#comp_results.to_csv('./compared_results_smol2.csv')
comp_results

Unnamed: 0,name,ndcg,ndcg_cut_10,recip_rank
0,Baseline BM25,0.284351,0.223677,0.229934
1,BM25+Bo1,0.278937,0.217306,0.216148
2,BM25+RM3,0.267636,0.207682,0.198609
3,BM25+LLM_ZS,0.036117,0.026599,0.02788
4,BM25+LLM_ZS+Q0,0.039504,0.029163,0.030257
5,BM25+LLM_FS,0.030944,0.023319,0.024362
6,BM25+LLM_FS+Q0,0.034896,0.026906,0.028636


### Next steps:
Take the query–document pairs from all of the training data and derive keywords from them.

Fine-tune an LLM on query → keywords pairs.

Use the fine-tuned LLM to derive keywords for the test queries, then run retrieval.