In [None]:
!pip install transformers
!pip install gradio=='3.21.0'
!pip install sentencepiece
!pip install accelerate
!pip install python-terrier
!pip install coolname

In [2]:
import pyterrier as pt
from pyterrier.measures import * # don't uncomment this #ok
import os
import tqdm
# import nltk
# from nltk.corpus import stopwords
# from nltk.tokenize import word_tokenize
import re
import torch
import numpy as np
import pandas as pd
from datetime import datetime
import coolname
import random




def print_seeds():
    """Print the seeds of the global torch and numpy random generators.

    This only reads state. The previous implementation called ``torch.seed()``
    and ``np.random.seed()``, both of which *re-seed* the generator, so merely
    printing the seeds destroyed reproducibility (and numpy printed ``None``).
    """
    # initial_seed() reports the seed without touching the generator.
    print(f'torch seed = {torch.initial_seed()}')
    # The first word of the legacy MT19937 key; it equals the integer seed only
    # right after np.random.seed(<int>) and before ~624 draws regenerate the key.
    print(f'numpy seed = {np.random.get_state()[1][0]}')

In [None]:
# Initialise PyTerrier only when it has not been started yet in this kernel,
# requesting the terrier-prf package at boot.
if not pt.started():
    pt.init(boot_packages=["com.github.terrierteam:terrier-prf:-SNAPSHOT"])

In [4]:
def cleanup(s1):
    """Strip ``s1`` and replace every non-alphanumeric character with a space."""
    pieces = []
    for ch in s1.strip():
        pieces.append(ch if ch.isalnum() else " ")
    return "".join(pieces)

def _load_or_build_index(evaluation_name, index_path, fields, meta=None, prebuilt_ref=None):
    """Open the Terrier index at ``index_path``, building it first if it is missing.

    Args:
        evaluation_name: dataset id handed to ``pt.get_dataset``.
        index_path: directory that holds (or will hold) the index.
        fields: document fields indexed when the index has to be built.
        meta: optional ``meta`` argument forwarded to ``pt.IterDictIndexer``.
        prebuilt_ref: location passed to ``pt.IndexRef.of`` when the index
            already exists; defaults to ``<index_path>/data.properties``.

    Returns:
        ``(index, dataset)``.
    """
    properties_path = index_path + '/data.properties'
    dataset = pt.get_dataset(evaluation_name)
    if not os.path.exists(properties_path):
        if meta is None:
            indexer = pt.IterDictIndexer(index_path)
        else:
            indexer = pt.IterDictIndexer(index_path, meta=meta)
        index_ref = indexer.index(dataset.get_corpus_iter(), fields=fields)
    else:
        print('Using prebuilt index.')
        index_ref = pt.IndexRef.of(properties_path if prebuilt_ref is None else prebuilt_ref)
    index = pt.IndexFactory.of(index_ref)
    print('Completed indexing')
    return index, dataset


def _title_plus_body_iterator(dataset, body_field):
    """Return a zero-argument callable yielding docs whose 'text' is title + body.

    Documents whose combined text is blank are skipped.
    """
    def corpus_iterator():
        for doc in dataset.get_corpus_iter():
            doc['text'] = doc['title'] + " " + doc[body_field]
            if doc['text'].strip():
                yield doc
    return corpus_iterator


def get_index(EVALUATION_NAME, index_name, field=None, colbert=False,):
    """Load (or build) the index for ``index_name`` and return it with its queries.

    The six dataset branches used to be copy-pasted; the shared load/build step
    and the title+body corpus iterator now live in private helpers. Results per
    dataset are unchanged.

    Args:
        EVALUATION_NAME: dataset id for ``pt.get_dataset`` (overridden for
            "trec-covid").
        index_name: one of "vaswani", "beir_dbpedia-entity",
            "beir_webis-touche2020_v2", "msmarco_passage", "msmarco_document",
            "trec-covid".
        field: accepted for call compatibility; not used here.
        colbert: when true, most datasets additionally return a callable that
            produces the corpus iterator.

    Returns:
        ``(index, dataset, queries)``, or ``(index, dataset, queries,
        corpus_iterator)`` when ``colbert`` is true. "msmarco_document" has no
        colbert variant and always returns three values. For "vaswani",
        "msmarco_passage" and "beir_dbpedia-entity" the colbert variant returns
        the raw ``dataset.get_topics()`` rather than cleaned queries.
        Returns ``None`` (after printing a message) for an unknown name, so
        callers that unpack the result must pass a supported name.
    """
    if index_name in ("vaswani", "msmarco_passage"):
        print(f"Loading Index {index_name}...")
        index, dataset = _load_or_build_index(EVALUATION_NAME, f'./indices/{index_name}', ['text'])
        if colbert:
            return index, dataset, dataset.get_topics(), dataset.get_corpus_iter
        queries = dataset.get_topics()
        queries['query'] = queries['query'].apply(cleanup)
        return index, dataset, queries

    if index_name == "beir_dbpedia-entity":
        print(f"Loading Index {index_name}...")
        index, dataset = _load_or_build_index(
            EVALUATION_NAME, f'./indices/{index_name}', ['text', 'title', 'url'], meta={"docno": 200})
        queries = dataset.get_topics()
        queries['query'] = queries['query'].apply(cleanup)
        if colbert:
            # Kept as in the original: the colbert variant hands back a fresh,
            # uncleaned topics frame.
            return index, dataset, dataset.get_topics(), _title_plus_body_iterator(dataset, 'text')
        return index, dataset, queries

    if index_name == "beir_webis-touche2020_v2":
        print(f"Loading Index {index_name}...")
        index, dataset = _load_or_build_index(
            EVALUATION_NAME, f'./indices/{index_name}', ['text', 'title', 'stance', 'url'], meta={"docno": 39})
        queries = dataset.get_topics()
        queries['query'] = queries['description'].str.cat(queries['text'], sep=' ')
        queries['query'] = queries['query'].apply(cleanup)
        if colbert:
            return index, dataset, queries, _title_plus_body_iterator(dataset, 'text')
        return index, dataset, queries

    if index_name == "msmarco_document":
        print(f"Loading Index {index_name}...")
        index_path = f'./indices/{index_name}'
        # NOTE(review): this dataset (and trec-covid) opens a prebuilt index via
        # the directory rather than data.properties, as the original did.
        index, dataset = _load_or_build_index(
            EVALUATION_NAME, index_path, ['url', 'title', 'body'], prebuilt_ref=index_path)
        queries = dataset.get_topics()
        queries['query'] = queries['query'].apply(cleanup)
        return index, dataset, queries

    if index_name == "trec-covid":
        print(f"Loading Index {index_name}...")
        # The caller-supplied evaluation name is deliberately replaced here.
        EVALUATION_NAME = "irds:cord19/trec-covid"
        index_path = './indices/cord19/trec-covid'
        index, dataset = _load_or_build_index(
            EVALUATION_NAME, index_path, ['title', 'doi', 'date', 'abstract'], prebuilt_ref=index_path)
        queries = dataset.get_topics()
        queries['query'] = queries['title'].str.cat(queries['description'], sep=' ')
        queries['query'] = queries['query'].apply(lambda text: text.replace("?", ""))
        queries['query'] = queries['query'].apply(cleanup)
        if colbert:
            return index, dataset, queries, _title_plus_body_iterator(dataset, 'abstract')
        return index, dataset, queries

    print(f"KD:No index selected of name {index_name}.")
    return None

def get_bm25_pipe(index_name, index):
    """Return a BM25 retriever for ``index_name``.

    For "trec-covid", "msmarco_passage" and "msmarco_document" the retriever is
    created with ``pt.BatchRetrieve.from_dataset`` (variant 'terrier_stemmed');
    every other name gets a retriever over the supplied ``index``.
    """
    dataset_backed_names = ("trec-covid", "msmarco_passage", "msmarco_document")
    if index_name not in dataset_backed_names:
        return pt.BatchRetrieve(index, wmodel='BM25')
    return pt.BatchRetrieve.from_dataset(index_name, 'terrier_stemmed', wmodel='BM25')

# Selectable datasets, one row each. on_dataset_change reads the columns as
# [evaluation/dataset id, index name, field, doc_field].
triplets = [
['irds:msmarco-passage/trec-dl-2019/judged',  'msmarco_passage', 'text', 'text'],
["irds:beir/webis-touche2020/v2", "beir_webis-touche2020_v2", "text", "text"],
["irds:beir/dbpedia-entity/test", "beir_dbpedia-entity", 'text', 'text'],
["irds:vaswani", "vaswani", 'text', 'text']]

# Module-level retrieval state: two retrievers and the docno -> document text map.
bm25= None; tfidf= None; docno2doctext = None
def on_dataset_change(dataset_name):
  """Load the index named ``dataset_name`` and make it the active retrieval state.

  Previously the new retrievers were only bound to local names, so selecting
  another index had no effect on the module-level ``bm25`` / ``tfidf`` /
  ``docno2doctext`` (or on ``retrieval_algos_dict``) that the search functions
  read. The globals are now rebound, and ``retrieval_algos_dict`` is refreshed
  in place when it already exists.

  Args:
    dataset_name: index name, matched against column 1 of ``triplets``.

  Returns:
    ``(bm25, tfidf, docno2doctext)`` for the selected dataset.

  Raises:
    IndexError: if no row of ``triplets`` matches ``dataset_name``.
  """
  global bm25, tfidf, docno2doctext
  triplet = [t for t in triplets if t[1]==dataset_name][0]
  EVALUATION_NAME = triplet[0]; index_name = triplet[1]; field = triplet[2]
  index, dataset, queries, corpus_iterator = get_index(EVALUATION_NAME, index_name, field, colbert=True)
  # Build everything first so a failure leaves the previous state untouched.
  new_docno2doctext = {doc['docno']: doc[field] for doc in corpus_iterator()}
  new_bm25 = pt.BatchRetrieve(index, wmodel='BM25')
  new_tfidf = pt.BatchRetrieve(index, wmodel='TF_IDF')
  bm25, tfidf, docno2doctext = new_bm25, new_tfidf, new_docno2doctext
  # retrieval_algos_dict is created after the first call, so it may not exist yet.
  algos = globals().get('retrieval_algos_dict')
  if isinstance(algos, dict):
    algos['BM25'] = bm25
    algos['TF_IDF'] = tfidf
  return bm25, tfidf, docno2doctext

def user_selects_different_index(index_id): # for dropdown
  """Load the dataset at position ``index_id`` of ``triplets``."""
  selected_index_name = triplets[index_id][1]
  return on_dataset_change(selected_index_name)

In [5]:
import os

# Folder (relative to the working directory) where logs and annotations are kept.
SAVED_DATA_DIRECTORY = "saved_data"

saved_data_dir_missing = not os.path.exists(SAVED_DATA_DIRECTORY)
if saved_data_dir_missing:
  os.mkdir(SAVED_DATA_DIRECTORY)


In [None]:
# Index names offered in the UI, in the same order as `triplets`.
dataset_names = [triplet[1] for triplet in triplets]
# default is id=3
bm25, tfidf, docno2doctext  = user_selects_different_index(3) # vaswani is the fastest to load. For testing use index_id = 3 (vaswani)
# Lookup from the retrieval-algorithm name shown in the UI to its retriever.
retrieval_algos_dict = {'BM25': bm25, 'TF_IDF': tfidf}
retrieval_algos_names = ['BM25','TF_IDF']

In [7]:
def get_doc_text(docno):
  """Return the stored text for ``docno``, or a 'not found' message if unknown."""
  missing_message = f"Document Text not found for Document ID = {docno}"
  return docno2doctext.get(docno, missing_message)

def retrieve_for_ui(query_text, pipeline):
  """Run the cleaned query through ``pipeline % 10`` and return the hits as dicts.

  Each hit gains 'eng-text' (looked up via ``get_doc_text``) and 'target-text'
  (a copy of 'eng-text').
  """
  top_ten = pipeline % 10
  hits = top_ten.search(cleanup(query_text))
  hits['eng-text'] = hits['docno'].apply(get_doc_text)
  hits['target-text'] = hits['eng-text']
  records = []
  for _, row in hits.iterrows():
    records.append(row.to_dict())
  return records

In [None]:
# Smoke test: run one query through the BM25 retriever and display the hits.
retrieve_for_ui('what is the capital of afghanistan', bm25)

In [9]:
# Keep the hits of a sample BM25 search in `pdd1` for manual inspection.
pdd1 = retrieve_for_ui('some search query', bm25)

In [10]:
import gradio as gr
# for batch size = 1
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import (
    BeamSearchScorer,
    LogitsProcessorList,
    StoppingCriteriaList,
    MaxLengthCriteria,
    ForcedBOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor, MinLengthLogitsProcessor
)

def clean1( text):
    """Turn raw generator output into a capitalised question ending in '?'.

    Removes the '<pad>' and '</s>' markers, trims whitespace, applies
    ``str.capitalize`` and appends '?' when it is missing.
    """
    for marker in ('<pad>', '</s>'):
        text = text.replace(marker, '')
    question = text.strip().capitalize()
    return question if question.endswith('?') else question + "?"

class DiverseGenerator(object):
    """Generates several candidate queries for a document with a T5 model.

    Two strategies are offered: sampled beam search through ``model.generate``
    and Hamming-diversity group beam search (``hamming_diverse``).
    """
    def __init__(self, forced_bos=True, hamming=True, model_name='castorini/doc2query-t5-base-msmarco', start_words=None):
        """Load tokenizer and model for ``model_name`` onto CUDA when available.

        ``start_words`` defaults to ['What', 'When', 'Which', 'Where', 'How'];
        a ``None`` sentinel replaces the former mutable list default.
        ``forced_bos``, ``hamming`` and ``start_words`` are only stored.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.model.to(device)
        if start_words is None:
            start_words = ['What', 'When', 'Which', 'Where', 'How']
        self.start_words = start_words
        self.forced_bos = forced_bos
        self.hamming = hamming

    def clean(self, text):
        """Normalise one generated string via ``clean1``."""
        return clean1(text)

    def diverse_generate(self, document, num_beams = 20, basic_beam_search=True, hamming=False):
        """Return generated queries for ``document``.

        Args:
            document: prompt/document text fed to the encoder.
            num_beams: beam count for the sampled beam search.
            basic_beam_search: when true, add 5 sequences from ``model.generate``.
            hamming: when true, add the outputs of ``hamming_diverse``.

        Returns:
            List of cleaned query strings (may contain duplicates).
        """
        document_tokenized = self.tokenizer(document, return_tensors='pt')
        # The model was moved to its device in __init__; follow it rather than
        # re-deriving the device on every call.
        device = self.model.device
        encoder_input_ids = document_tokenized['input_ids'].to(device)
        generations = []
        if basic_beam_search:
            generated_ids = self.model.generate(encoder_input_ids, max_length=128,
                                           pad_token_id=self.tokenizer.eos_token_id, num_beams=num_beams, num_return_sequences=5, temperature = 1.3,top_p=0.92, repetition_penalty =2.1, do_sample=True).tolist()
            preds = [self.clean(self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True, )) for g in generated_ids]
            generations.extend(preds)
        if hamming:
            # One decoder-start row is enough; hamming_diverse expands it to its
            # own beam count. The old code built `num_beams` (20) rows here while
            # hamming_diverse ran with 6 beams, giving mismatched batch sizes.
            input_ids = torch.full((1, 1), self.model.config.decoder_start_token_id,
                                   dtype=torch.long, device=device)
            generations.extend(self.hamming_diverse(encoder_input_ids, input_ids,))
        return generations

    def hamming_diverse(self, encoder_input_ids, input_ids, num_beams = 6, num_beam_groups=3):
        """Run Hamming-diversity group beam search and return the cleaned outputs.

        Args:
            encoder_input_ids: tokenised document; repeated ``num_beams`` times.
            input_ids: decoder start tokens. If its row count differs from
                ``num_beams`` the first row is repeated ``num_beams`` times so
                it matches the repeated encoder outputs.
            num_beams: total number of beams.
            num_beam_groups: number of diversity groups.

        Returns:
            List of cleaned query strings.
        """
        # NOTE(review): `group_beam_search` and BeamSearchScorer's `max_length`
        # may be unavailable in newer transformers releases — confirm against
        # the installed version.
        if input_ids.shape[0] != num_beams:
            input_ids = input_ids[:1].repeat(num_beams, 1)
        # Inference only: avoid building autograd graphs for the manual
        # encoder call and the beam search.
        with torch.no_grad():
            model_kwargs = {
                "encoder_outputs": self.model.get_encoder()(
                    encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
                )
            }
            # instantiate beam scorer
            beam_scorer = BeamSearchScorer(
                batch_size=1,
                max_length=self.model.config.max_length,
                num_beams=num_beams,
                num_beam_hyps_to_keep=num_beams,
                device=self.model.device,
                num_beam_groups=num_beam_groups,
            )
            logits_processor = LogitsProcessorList(
                [HammingDiversityLogitsProcessor(5.5, num_beams=num_beams, num_beam_groups=num_beam_groups),
                 MinLengthLogitsProcessor(8, eos_token_id=self.model.config.eos_token_id),
                 ]
            )
            outputs = self.model.group_beam_search(
                input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
            )
        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return [self.clean(gen) for gen in decoded]


In [11]:

def highlighted_doc(d):
    """Pair every space-separated word of ``d['eng-text']`` with its event label.

    ``d['event_str']`` is iterated and the first element of each item supplies
    a 'string' -> str('label') mapping. Returns ``[]`` when that mapping is
    empty; otherwise a list of ``(word, label_or_None)`` tuples.
    """
    document = d['eng-text']
    label_by_word = {}
    for event in d['event_str']:
        head = event[0]
        label_by_word[head['string']] = str(head['label'])
    if not label_by_word:
        return []
    return [(word, label_by_word.get(word)) for word in document.split(" ")]

def lang_to_emoji(lang):
    """Map a three-letter language code to a flag prefix; unknown codes get "📜"."""
    flag_prefixes = (
        ("fas", "🇮🇷 "),
        ("eng", "🇺🇸 "),
        ('kor', "🇰🇷 "),
        ('rus', "🇷🇺 "),
        ('zho', "🇨🇳 "),
        ('ara', '🇸🇦 '),
    )
    for code, prefix in flag_prefixes:
        if lang == code:
            return prefix
    return "📜"

# Number of document slots shown in the UI; document_retriever pads its output to this many.
NDOCS = 10
# Module-level character limit; document_retriever has its own parameter of the same name.
max_text_lim = 400
## TODO: update this to take the chosen retrieval method into account

def document_retriever(session, query_text,logs_file_name, model_choice_retrival=retrieval_algos_names[0], max_text_lim=400):
    """Search for ``query_text`` and flatten the hits into UI component values.

    Unless ``logs_file_name`` is "manual_search" the query is logged first via
    ``on_query_change``. Each hit contributes four values: truncated target
    text, truncated flag-prefixed English text, score, and an empty string.
    The list is padded with empty strings up to ``NDOCS`` hits.

    Returns:
        A list of ``NDOCS * 4`` values; all empty strings if the assembled list
        does not have exactly that length.
    """
    if logs_file_name != "manual_search":
      on_query_change(query_text,logs_file_name, session)
    retriever = retrieval_algos_dict[model_choice_retrival]
    hits = retrieve_for_ui(query_text, retriever)
    english_texts = [hit['eng-text'] for hit in hits]
    target_texts = [hit['target-text'] for hit in hits]
    scores = [hit['score'] for hit in hits]

    ui_values = []
    for target, english, score in zip(target_texts, english_texts, scores):
        shortened_target = str(target).strip()[:max_text_lim].strip() + "..."
        shortened_english = "🇺🇸  "+str(english).strip()[:max_text_lim].strip() + "..."
        # Fourth slot is the (currently unused) event-annotation value.
        ui_values.extend([shortened_target, shortened_english, score, ""])

    # Blank out the slots that received no hit.
    unused_slots = max(0, NDOCS - len(english_texts))
    ui_values.extend([""] * (4 * unused_slots))

    expected_length = NDOCS*4
    if len(ui_values) != expected_length:
        return [""]*NDOCS*4
    return ui_values

def add_to_new_query(current_text_box, to_add):
    """Return the current query text with ``to_add`` appended after one space."""
    return " ".join((current_text_box, to_add))

# Shortened labels for the two sample documents.
eg1="Eleven people were killed in a train crash in northern Italy..."
eg2="Based on the U.S. Electronics Network (CNN) cited Peruvian..."

def sample_doc_examples(inshort):
    """Expand a shortened sample label into its full text (``None`` if unknown)."""
    full_texts = (
        (eg1, "Eleven people were killed in a train crash in northern Italy when the train crashed in the north of Italy The number of victims of the train disaster in northern Italy has grown to 11 people reported by the head of the administration of the autonomous province of Bolzano Louis Durnwalder."),
        (eg2, "Based on the U.S. Electronics Network (CNN) cited Peruvian officials as saying that it is currently known that 2 people have died and 65 have been injured and the death toll has risen to 1 person."),
    )
    for label, full_text in full_texts:
        if inshort == label:
            return full_text
    return None

In [None]:
# Model used for the query generator that is created when this cell runs.
model_name0 = 'google/flan-t5-large'
# Shared generator instance read by the query_generator* functions.
diverseGenerator = DiverseGenerator(model_name=model_name0)

def update_model(model_name):
  """Switch the shared query generator to ``model_name`` if it is not already loaded.

  Fixes three defects of the previous version: strings were compared by
  identity (``is not``), the replacement generator was built from
  ``model_name0`` instead of the requested model, and it was assigned to a
  local name so the shared ``diverseGenerator`` never changed.
  ``model_name0`` is updated to track the model that is currently loaded.
  """
  global diverseGenerator, model_name0
  if not model_name:
    # Nothing selected (e.g. cleared dropdown): keep the current generator.
    return
  if model_name != model_name0:
    print(f"Updating query generators to {model_name}")
    diverseGenerator = DiverseGenerator(model_name=model_name)
    model_name0 = model_name

In [41]:
import random
# basic beam search
#this is what we are using now to generate queries, kd pls change it if you want in gradio
def query_generator1(prompt_instruction,prompt_context, prompt_input, model_name, session):
    """Generate up to five distinct queries for ``prompt_input``.

    The prompt is instruction, context, "Document:<input>" and "Query:" joined
    by newlines, and is sent to the shared ``diverseGenerator``. Duplicates are
    removed while keeping generation order (the earlier ``list(set(...))``
    produced an arbitrary, run-dependent order).

    Args:
        model_name: accepted for the UI wiring; the shared generator is used.

    Returns:
        Tuple ``(session, q1, q2, q3, q4, q5)``; missing queries are "".
    """
    prompt = prompt_instruction + "\n" + prompt_context+"\nDocument:"+prompt_input+"\nQuery:"
    generated = diverseGenerator.diverse_generate(prompt)
    unique_queries = list(dict.fromkeys(generated))  # order-preserving dedupe
    slots = (unique_queries + [""] * 5)[:5]
    return (session, *slots)

def query_generatortemp(prompt_instruction,prompt_context, prompt_input, model_name, session):
    """Alias of ``query_generator1`` kept for the existing UI wiring.

    This function was a line-for-line copy of ``query_generator1``; it now
    delegates so the two cannot drift apart.
    """
    return query_generator1(prompt_instruction, prompt_context, prompt_input, model_name, session)
# hamming diversity beam search
def query_generator2(input,model_name):
    """Generate up to five distinct queries with Hamming-diversity beam search.

    A new ``DiverseGenerator`` for ``model_name`` is created on every call
    (this loads the model each time). Duplicates are removed while keeping
    generation order instead of the arbitrary order of ``list(set(...))``.

    Returns:
        List of exactly five strings; missing queries are "".
    """
    generator = DiverseGenerator(model_name=model_name)
    generated = generator.diverse_generate(input, basic_beam_search=False, hamming=True)
    unique_queries = list(dict.fromkeys(generated))  # order-preserving dedupe
    return (unique_queries + [""] * 5)[:5]


In [14]:
import json
#saving docs
def save_document_auto(session, model_name,file_name,query, *document_texts):
    """Append one search result set to ``saved_data/<file_name>.json``.

    ``document_texts`` is read in groups of four: document text, translation,
    score, rating. Groups whose first value is falsy are skipped; a falsy
    rating is stored as 0. The file holds a JSON list; an existing non-list
    payload is wrapped together with the new record.
    """
    relevance_annotations = []
    group_starts = range(0, len(document_texts), 4)
    for position, start in enumerate(group_starts, start=1):
        doc_text = document_texts[start]
        if not doc_text:
            continue
        translation = document_texts[start + 1]
        score = document_texts[start + 2]
        rating = document_texts[start + 3] or 0
        relevance_annotations.append(
            {"id": position, "doc_test": doc_text, "translation": translation, "score":score,"rating": rating}
        )

    final_data = {
        "session": session,
        "query": query,
        "model_choice":model_name,
        "relevance_annotations": relevance_annotations,
    }

    folder_path = 'saved_data'
    file_path = f'{folder_path}/{file_name}.json'
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)

    has_previous_records = os.path.exists(file_path) and os.path.getsize(file_path) > 0
    if has_previous_records:
        with open(file_path, "r") as jsonfile:
            stored = json.load(jsonfile)
        if isinstance(stored, list):
            stored.append(final_data)
            records = stored
        else:
            records = [stored, final_data]
    else:
        records = [final_data]

    with open(file_path, "w") as jsonfile:
        json.dump(records, jsonfile, indent=4)

In [15]:
# Instruction templates offered in the "Choose Instruction" dropdown.
instructions = ["Generate a query which is relevant to a given document",
                "Generate a query which is relevant to a given document and is different from previously generated query",
                "Generate a document relevant query with different words from the document"]
# Default instruction: the first template.
instruction = instructions[0]

from transformers import T5Tokenizer, T5ForConditionalGeneration
def get_flant5xl():
    """Load the tokenizer and model for "google/flan-t5-large".

    Despite the function name, the checkpoint loaded is the *large* variant.

    Returns:
        ``(tokenizer, model)``.
    """
    checkpoint = "google/flan-t5-large"
    tokenizer = T5Tokenizer.from_pretrained(checkpoint)
    model = T5ForConditionalGeneration.from_pretrained(checkpoint, device_map="auto")
    return tokenizer, model

def dqpairs(documents, queries):
  """Render document/query pairs as newline-joined "Document:..." / "Query:..." lines."""
  blocks = []
  for doc, query in zip(documents, queries):
    blocks.append(f"Document:{doc}\nQuery:{query}")
  return "\n".join(blocks)

def dqprompt(documents, queries):
  """Render a few-shot prompt: answered pairs, then the final document without a query.

  All documents except the last are paired with ``queries``; the last document
  is appended on its own "Document:" line.
  """
  answered_pairs = zip(documents[:-1], queries)
  lines = [f"Document:{doc}\nQuery:{query}" for (doc, query) in answered_pairs]
  return "\n".join(lines) + f"\nDocument:{documents[-1]}"

def get_dqlists_from_prompt(prompt_text):
  """Split a prompt into its instruction and its answered document/query pairs.

  The first line is the instruction. The following lines are read two at a
  time as "Document:<text>" / "Query:<text>"; the final two lines (the pending
  document and the empty "Query:") are not returned.

  Fixes: text after the first ':' is kept whole (``split(":")[1]`` dropped
  everything after a second colon), and the unused read of the loop variable
  after the loop is gone — it raised UnboundLocalError for prompts without any
  answered pair.

  Returns:
    ``(documents, queries, instruction)``.
  """
  dqs = prompt_text.split('\n')
  instruction = dqs[0]
  ds = []
  qs = []
  for i in range(1, len(dqs) - 2, 2):
    # partition keeps everything after the first colon and tolerates its absence.
    ds.append(dqs[i].partition(":")[2])
    qs.append(dqs[i+1].partition(":")[2])
  return ds, qs, instruction

def update_prompt(model_input, document):
    """Rebuild the prompt so it ends with ``document`` as the pending document.

    The flag prefix is removed from ``document``. Only the last two answered
    document/query pairs of ``model_input`` are kept. Previously the query
    list was sliced using the length of the *already truncated* document list,
    which kept every query and misaligned queries with documents.

    Returns:
        Instruction, kept pairs, the new document and a trailing "Query:".
    """
    document = document.replace("🇺🇸","").strip()
    ds, qs, instruction = get_dqlists_from_prompt(model_input)
    # now keep last 2 doc-query pairs
    if len(ds) > 2:
        ds, qs = ds[-2:], qs[-2:]
    ds.append(document)
    return instruction +"\n"+ dqprompt(ds, qs) + "\nQuery:"

def send_feedback(input_query, document, model_choice_query, session, file_name="feedback_query_reformulations"):
  """Extend ``input_query`` with keywords generated from a feedback document.

  The flag prefix is stripped from ``document``, the first generated query is
  appended to ``input_query``, and the reformulated query is logged through
  ``on_query_change`` before being returned.
  """
  context_document = document.replace("🇺🇸","").strip()
  instruction_text = f"Based on the given context ```{context_document}```, generate keywords for the query : "
  generated = query_generator1(instruction_text, "", input_query, model_choice_query, session)
  # Element 0 is the session id; element 1 is the first generated query.
  reformulated = input_query + " " + generated[1]
  on_query_change(reformulated, file_name, session)
  return reformulated

def append_keywords(session, query, reform_method, reform_instruction, model_choice_query, file_name="query_reformulations"):
  """Append generated keywords to ``query`` and log the reformulated query.

  The user-editable ``reform_instruction`` is now used as the generation
  instruction; before, it was received but ignored in favour of a hard-coded
  string. The hard-coded instruction remains the fallback when the box is
  empty. ``reform_method`` is accepted for the UI wiring and not used.

  Returns:
    ``query`` followed by a space and the first generated query.
  """
  fallback_instruction = "Generate keywords for the query : "
  has_instruction = isinstance(reform_instruction, str) and reform_instruction.strip()
  instruction_text = reform_instruction if has_instruction else fallback_instruction
  rf_queries = query_generator1(instruction_text, "", query, model_choice_query, session)
  # Element 0 is the session id; element 1 is the first generated query.
  ref_query = query + " " + rf_queries[1]
  on_query_change(ref_query, file_name, session)
  return ref_query

# Example documents and queries used to pre-fill the few-shot context box in the UI.
ds_eg = ["24 Aug 2016 At least 38 people died in a powerful earthquake that hit central Italy early on Wednesday 250 people are now known to have died in the earthquake that hit central Italy on Wednesday",
      "As of Thursday morning, the deaths totaled 241, officials said. 6.2 Earthquake In Central Italy, At Least 37 Dead"]
qs_eg = ["When did the earthquake happen?", "How many people died in the earthquake?"]

# Sample documents; the first one pre-fills the "Provide Example Document" box.
sample_docs = ["China’s state-run Citic Group, the main developer of the project, said negotiations were ongoing and that the $1.3bn was to be spent on the “initial phase” of the port, adding the project was divided into four phases. It did not elaborate on plans for subsequent stages. Xi and Mitsotakis will visit the port of Piraeus, Greece’s largest and majority owned by Chinese port operator Cosco. It is the biggest Chinese investment in Greece and Cosco recently received approval for a new investment plan that includes building a new cargo terminal.",
               "FireEye, one of the largest cyber security companies in the United States, said on Tuesday that it had been hacked, likely by a government, and that an arsenal of hacking tools used to test the defences of its clients had been stolen.  The hack of FireEye, a company with an array of contracts across the national security space both in the United States and its allies, is among the most significant breaches in recent memory.",
               "The alleged state-backed hacking groups engaging in these attacks include a group from Russia code-named ‘Strontium’ and two groups from North Korea code-named ‘Zinc’ and ‘Cerium’. "]


In [16]:
def display_inter_search_content(file_name):
    """Render the saved interactive-search records of ``saved_data/<file_name>`` as HTML.

    Each record with more than one relevance annotation becomes a heading block
    (query, model choice, session id) followed by a table of its annotations.
    Query, model and session text come from users, so they are HTML-escaped
    before being embedded (``DataFrame.to_html`` escapes table cells itself).

    Returns:
        The HTML string, or "Error: <message>" if anything fails.
    """
    from html import escape  # stdlib; local so the block is self-contained

    try:
        with open(f'saved_data/{file_name}', 'r') as file:
            data = json.load(file)
        parts = []
        for entry in data:
            # NOTE(review): records with exactly one annotation are skipped by
            # this `> 1` check — confirm that is intended.
            if len(entry['relevance_annotations']) >1:
                query = "Query: " + entry["query"]
                parts.append(f"<h3>{escape(query)}</h3>")
                if entry["model_choice"]:
                    model_choice = "Model Choice: " + str(entry["model_choice"])
                    parts.append(f"<h4>{escape(model_choice)}</h4>")
                if entry["session"]:
                    session_choice = "Session_iD: " + str(entry["session"])
                    parts.append(f"<h4>{escape(session_choice)}</h4>")
                df = pd.DataFrame(entry["relevance_annotations"])
                df = df.rename(columns={'doc_test': 'Doc_Test', 'translation': 'Translation', 'score': 'Score', 'rating': 'Rating'})
                parts.append(df.to_html(border=0, index=False))
        return "".join(parts)
    except Exception as e:
        return f"Error: {str(e)}"

def read_jsonl(file_path):
    """Read a JSON-lines file (one JSON object per line) into a DataFrame."""
    frame = pd.read_json(file_path, lines=True)
    return frame

def display_jsonl_content(file_name, session_id=None):
    """Render a JSON-lines log from ``saved_data/`` as HTML, optionally per session.

    'interactive_search.json' is delegated to ``display_inter_search_content``.

    Changes from the previous version:
      * ``session_id`` is optional, so a handler wired with only the file name
        works (it used to fail for lack of a second argument).
      * The session column is matched case-insensitively; the logs written by
        this notebook use 'session', while the filter looked for 'Session'.
      * Column names and values are HTML-escaped, as they contain user text.

    Args:
        file_name: file inside ``saved_data/``.
        session_id: falsy or "Everyone" shows every row.

    Returns:
        HTML string, "No data found for this session ID", or "Error: <message>".
    """
    from html import escape  # stdlib; local so the block is self-contained

    if file_name == 'interactive_search.json':
        return display_inter_search_content(file_name)
    try:
        df = read_jsonl(f'saved_data/{file_name}')
        if session_id and session_id != "Everyone":
            session_col = next((c for c in df.columns if str(c).lower() == 'session'), None)
            if session_col is None:
                return "No data found for this session ID"
            df = df[df[session_col] == session_id]
        parts = []
        column_names = df.columns
        for index, row in df.iterrows():
            for col in column_names:
                parts.append(f"<p><b>{escape(str(col).capitalize())}:</b> {escape(str(row[col]))}</p>")
            parts.append("<hr>")
        content = "".join(parts)
        return content if content else "No data found for this session ID"
    except Exception as e:
        return f"Error: {str(e)}"


def list_files(directory):
    """Return the names of the regular files directly inside ``directory``."""
    regular_files = []
    for entry in os.listdir(directory):
        if os.path.isfile(os.path.join(directory, entry)):
            regular_files.append(entry)
    return regular_files


def get_unique_sessions(file_name):
    """Return the distinct session ids recorded in a JSON-lines log.

    The session column is matched case-insensitively: the logs written by this
    notebook use the key 'session', so the former lookup of 'Session' always
    failed and this function always returned ``[]``.

    Returns:
        List of session ids; ``[]`` if the file cannot be read or has no
        session column (best-effort, as before).
    """
    try:
        df = read_jsonl(f'{SAVED_DATA_DIRECTORY}/{file_name}')
        session_col = next((c for c in df.columns if str(c).lower() == 'session'), None)
        if session_col is None:
            return []
        return df[session_col].unique().tolist()
    except Exception as e:
        return []


# Path prefix (with trailing slash) that the logging helpers prepend to log file names.
directory = f'{SAVED_DATA_DIRECTORY}/'


def generate_cool_name():
    """Return a session name: a two-word coolname slug followed by a number in 10..99."""
    slug = coolname.generate_slug(2)
    two_digit_suffix = random.randint(10, 99)
    return slug + f'{two_digit_suffix}'

In [17]:
def write_to_jsonl(file_name, data):
    """Append ``data`` as one JSON line to ``file_name``.

    The parent directory is created when missing. The old ``FileNotFoundError``
    fallback could not help — append mode already creates a missing file, so
    the error only occurs when the directory is absent, and re-opening the same
    path fails again. That fallback also JSON-encoded the already encoded
    string a second time and omitted the newline.
    """
    line = json.dumps(data)
    parent = os.path.dirname(file_name)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(file_name, 'a') as file:
        file.write(line + "\n")

# Last query seen per session id; used to log the previous query next to the new one.
previous_values = {}


def on_query_change(query, file_name, session):
    """Append a query-change record to ``<directory>query_log.json``.

    The record stores the session, the reformulator type (``file_name``), the
    session's previous query ("null" when there is none) and the new query with
    a timestamp. The new query is then remembered for the session.
    """
    last_query = previous_values.get(session, '')
    record = {
        'session': session,
        'type': 'query',
        'query_reformulator_type':file_name,
        'previous_content': last_query if last_query != '' else "null",
        'current_content': query,
        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }
    write_to_jsonl(directory + "query_log.json", record)
    previous_values[session] = query


def save_relavance_annotations(session, query, document, rating):
  """Append a relevance rating to ``<directory>relavance_annotations.json``.

  Nothing is written when ``rating`` is the empty string.
  """
  if rating != "":
    annotation_record = {
        'session': session,
        'type': 'annotation',
        'query': query,
        'document': document,
        'annotation': rating,
        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }
    write_to_jsonl(directory + "relavance_annotations.json", annotation_record)

In [18]:
# Model names offered in the "Select Query Generator" dropdown.
query_model_choices = ['castorini/doc2query-t5-base-msmarco','google/flan-t5-large']

In [19]:
def display_session_name(session_id):
    """Return ``session_id`` unchanged."""
    shown_name = session_id
    return shown_name

In [None]:
import gradio as gr

# Gradio app: a search/annotation tab and a configuration tab.
# Fixes relative to the previous wiring:
#   * "Query Tuner" is passed as the page title (the first positional argument
#     of gr.Blocks is the theme).
#   * Reset handlers registered with inputs=None no longer declare a parameter.
#   * Feedback checkboxes read the matching "Translation" box: every document
#     contributes four components to docs_en2, so the stride is 4, not 3.
#   * The log viewer passes a session filter to display_jsonl_content, which
#     takes two arguments.
#   * Changing the index rebinds the retrievers that the search handler reads.
with gr.Blocks(title="Query Tuner") as demo:
    session = gr.Textbox(info="Your", visible=False)
    demo.load(fn=generate_cool_name, inputs=None, outputs=session)
    with gr.Tab("🤖 User Interface"):
        with gr.Row():
          # QBE
          prompt_input = gr.Textbox(label="Provide Example Document",info="This document would be used to generate a query",
                                        value=sample_docs[0])
          # Ad-hoc
          interactive_query = gr.Textbox(value=(" "*len(sample_docs[0])), label="Search Query", info="Search and Retrieve documents for one of the generated queries", interactive=True)
        with gr.Row():
          with gr.Column():
            btni3 = gr.Button(value ="⚙️ Generate Query                    ",  size="sm")
            outputs3 = []
            rds3 = []
            for i in range(5):
              with gr.Row():
                with gr.Column():
                  outputs3.append(gr.Textbox(value="", label=f"Query {str(i+1)}", interactive=True))
                  rds3.append(gr.Checkbox(label="Add to query",elem_id=f"cb{str(i)}",interactive=True))
          with gr.Column():
            with gr.Row():
              btn_reform = gr.Button(value="Reformulate", size="sm", css=".gradio-container { width: 2.5vw}")
              btn2 = gr.Button(value="🔍 Search", size="sm", css=".gradio-container { width: 5vw}")
            # Four components per document, in this order:
            # document text, translation, score (hidden), rating slider.
            docs_en2 = []
            rds4 = []
            for i in range(NDOCS):
              with gr.Row():
                with gr.Column():
                  dx = gr.Textbox(value="",  label=f"Document {str(i+1)}", lines=3)
                  docs_en2.append(dx)
                  docs_en2.append(gr.Textbox(value="", label="Translation", lines=3))
                  docs_en2.append(gr.Textbox(value="", label="Score", lines=3, visible=False))
                  with gr.Row():
                    slider = gr.Slider(minimum=0, maximum=10, label=f"Rate Document {i+1}", value=-1)
                    slider.change(save_relavance_annotations, inputs=[session, interactive_query, dx, slider])
                    docs_en2.append(slider)
                    # docs_en2.append(gr.HighlightedText(value="", label=f"Event Annotations of Document {str(i+1)}"))
                    rds4.append(gr.Checkbox(label="Use this document to improve the query", elem_id=f"cb{str(i)}", interactive=True))
            for i in range(5):
              rds3[i].select(fn=add_to_new_query, inputs=[interactive_query, outputs3[i]], outputs=interactive_query)
            save_btn_inter = gr.Button(value="💾 Save Results" , size="sm", css=".gradio-container { width: 5vw}")
        filename_input = gr.Textbox(value='interactive_search', visible=False)
        # These handlers receive no inputs, so they must take no parameters.
        btn2.click(lambda: [False]*NDOCS, inputs=None, outputs=rds4)
        btni3.click(lambda: "", inputs=None, outputs=interactive_query)
        btni3.click(lambda: [False]*5, inputs=None, outputs=rds3)
    with gr.Tab("⚙️ Configuration"):
        json_choice_list = ['interactive_search.json','query_log.json', 'relavance_annotations.json']
        with gr.Row():
            model_choice_query = gr.Dropdown(
                label="Select Query Generator",
                choices = query_model_choices,
                value=query_model_choices[1] ,
                size="sm"
              )
            prompt_instruction = gr.Dropdown(instructions,label="Choose Instruction",info="Different instructions can invoke a variety of responses from LLMs",value=instructions[1],interactive=True )
            prompt_context = gr.Textbox(label="Sample Document-Query Pairs",
                                    placeholder="Provide text to generate query...", value=dqpairs(ds_eg[1:], qs_eg[1:]),
                                       lines=5)
            btni3.click(query_generatortemp, inputs=[prompt_instruction,prompt_context, prompt_input,model_choice_query, session], outputs=[session,*outputs3])
        with gr.Row():
          reform_method = gr.Dropdown(['Zero-Shot QR', 'Few-Shot QR'],
                                   label="Choose Type of Query Reformulator",info="Additional keywords would be added to your original query",
                                   value='Zero-Shot QR')
          reform_instruction = gr.Textbox(value="Suggest useful keywords to improve the retrieval effectiveness of the query: ", label="Reformulator Instruction", info="Use the following instruction to reform the query inplace", interactive=True)
          reform_context = gr.Textbox(label="Sample Query-Reformed Query Pairs",
                                    placeholder="Provide text to generate query...", value="",
                                       lines=5)
        with gr.Row():
            model_choice_retrival = gr.Dropdown(
                label="Select retrieval pipeline to use", info="The query would be run against the index to retrieve the top 10 documents.",
                choices=retrieval_algos_names,
                value=retrieval_algos_names[0],
              )
            index_choice_retrieval = gr.Dropdown(
                label="Select retrieval index to use", info="The query would be run against this index.",
                choices=dataset_names,
                value=dataset_names[3],
              )

            def _switch_index(dataset_name):
                """Load the selected index and rebind the retrieval state used by search."""
                global bm25, tfidf, docno2doctext
                bm25, tfidf, docno2doctext = on_dataset_change(dataset_name)
                retrieval_algos_dict['BM25'] = bm25
                retrieval_algos_dict['TF_IDF'] = tfidf

            index_choice_retrieval.change(_switch_index, inputs=index_choice_retrieval)
            btn2.click(document_retriever, inputs=[session, interactive_query, filename_input, model_choice_retrival], outputs=docs_en2)
            btn2.click(save_document_auto, inputs = [session, model_choice_retrival,filename_input,interactive_query,*docs_en2])
        with gr.Row():
            dropdown = gr.Dropdown(choices=json_choice_list, label="Select JSON File")
            display_button = gr.Button("Display Recorded Annotations & Logs")
        output = gr.HTML()
        model_choice_query.change(update_model, inputs=model_choice_query)
        # No session picker exists in the UI, so show the records of every session.
        display_button.click(lambda file_name: display_jsonl_content(file_name, "Everyone"), inputs=dropdown, outputs=output)
        save_btn_inter.click(save_document_auto, inputs = [session, model_choice_retrival,filename_input,interactive_query,*docs_en2])
        btn_reform.click(append_keywords, inputs=[session, interactive_query, reform_method, reform_instruction, model_choice_query], outputs=interactive_query)
        for i in range(NDOCS):
              # 4*i + 1 is the "Translation" box of document i (the flag-prefixed text).
              rds4[i].select(fn=send_feedback, inputs=[interactive_query, docs_en2[4*i + 1], model_choice_query, session], outputs=interactive_query)



In [None]:
# Start the app. NOTE(review): share=True requests a public share link — confirm
# this is wanted, since anyone with the link can use the app and write logs.
demo.launch(share=True, debug=True)

In [None]:
# Shut the running app down.
demo.close()
