<a href="https://colab.research.google.com/github/giuliocapecchi/IR_project/blob/main/IR_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
#%pip install torch matplotlib nltk tqdm gdown ir_datasets humanize

# 1. Download and prepare the collection

In [None]:
# Name of the collection to index and evaluate; must be one of ["vaswani", "msmarco"].
# Every later cell derives its paths, queries and qrels from this value.
chosen_collection = "msmarco"

In [21]:
import gdown
import ir_datasets
import pandas as pd
import os

if chosen_collection not in ["vaswani", "msmarco"]:
    raise ValueError("chosen_collection must be one of ['vaswani', 'msmarco']")

if chosen_collection == "msmarco":
    # The MS MARCO passage collection TSV is fetched from Google Drive, only if not already on disk.
    os.makedirs('./collection/msmarco', exist_ok=True)

    url_collection = 'https://drive.google.com/uc?id=1_wXJjiwdgc9Kpt7o7atP8oWe-U4Z56hn'

    if not os.path.exists('./collection/msmarco/msmarco.tsv'):
        gdown.download(url_collection, './collection/msmarco/msmarco.tsv', quiet=False)

elif chosen_collection == "vaswani":
    # Vaswani is loaded through ir_datasets and exported to a headerless, one-document-per-line
    # "<doc_id>\t<text>" TSV, the layout build_index parses. `vaswani_dataset` is kept because
    # the query cell reads its queries and qrels from it.
    os.makedirs('./collection/vaswani', exist_ok=True)

    vaswani_dataset = ir_datasets.load(chosen_collection)
    docs = list(vaswani_dataset.docs_iter())
    df = pd.DataFrame(docs)
    # Shift ids to 0-based so a docid can directly index the document list built by build_index.
    df['doc_id'] = (df['doc_id'].astype(int) - 1).astype(str)
    # Remove newlines from every document so each one stays on a single TSV line.
    df['text'] = df['text'].str.replace('\n', ' ')
    if not os.path.exists('./collection/vaswani/vaswani.tsv'):
        df.to_csv('./collection/vaswani/vaswani.tsv', sep='\t', header=False, index=False)

Standard text preprocessing (lowercasing, punctuation removal, tokenisation and stopword removal), with stemming done by the *PyStemmer* library.

In [22]:
import re
import string
import nltk
import Stemmer # PyStemmer


nltk.download("stopwords", quiet=True)
STOPWORDS = set(nltk.corpus.stopwords.words("english"))
STEMMER = Stemmer.Stemmer('english')
# stemmer = nltk.stem.PorterStemmer().stem # much slower than PyStemmer

# preprocess() runs once per document of the collection, so the regexes and the punctuation
# translation table are built once here instead of on every call.
_QUOTES_RE = re.compile(r"[‘’´“”–-]")
_ACRONYM_DOT_RE = re.compile(r"\.(?!(\S[^. ])|\d)")
_WHITESPACE_RE = re.compile(r"\s+")
_PUNCT_TABLE = str.maketrans(string.punctuation, " " * len(string.punctuation))


def preprocess(s):
    """Normalise, tokenise, remove stopwords from and stem a piece of text.

    Used for both documents (build_index) and queries, so both go through identical steps.

    :param s: raw text
    :returns: the stems produced by PyStemmer for the non-stopword tokens, in text order
        (duplicates are kept)
    """
    # lowercasing
    s = s.lower()
    # '&' becomes ' and '; curly quotes, acute accent, en dash and hyphen become "'"
    # (which the punctuation step below then turns into a space)
    s = _QUOTES_RE.sub("'", s.replace("&", " and "))
    # remove a dot unless it is followed by a digit, or by a non-space character plus a
    # character other than '.' or ' ': "u.s.a." -> "usa", while "3.14" keeps its dot here
    s = _ACRONYM_DOT_RE.sub("", s)
    # replace every ASCII punctuation character with a space
    s = s.translate(_PUNCT_TABLE)
    # collapse runs of whitespace
    s = _WHITESPACE_RE.sub(" ", s).strip()
    # tokenisation + stopword removal (on the unstemmed tokens)
    tokens = [t for t in s.split() if t not in STOPWORDS]
    # stemming
    return STEMMER.stemWords(tokens)

In [23]:
import time

def profile(f):
    """Decorator that prints the elapsed time of every call to ``f`` in milliseconds.

    ``functools.wraps`` keeps ``f``'s name and docstring on the wrapper, so tools such as
    cProfile and tracebacks still show the real function. ``time.perf_counter`` is used
    because it is monotonic and high-resolution (``time.time`` can jump if the wall clock
    is adjusted).
    """
    import functools

    @functools.wraps(f)
    def f_timer(*args, **kwargs):
        start = time.perf_counter()
        result = f(*args, **kwargs)
        ms = (time.perf_counter() - start) * 1000
        print(f"{f.__name__} ({ms:.3f} ms)")
        return result
    return f_timer

In [24]:
# TODO rivedere

import pickle
import humanize
import os
from tqdm import tqdm

def print_pickled_size(var_name, var):
    """Print how large ``var`` is once pickled, in human-readable units.

    The pickle stream goes to a byte-counting sink rather than a temporary file. Nothing is
    written to disk, no ``./tmp`` directory is created or deleted, and memory stays flat
    even for a very large object, because the bytes are never accumulated.

    :param var_name: label used in the printed message
    :param var: any picklable object
    """
    class _ByteCounter:
        # Minimal file-like object: pickle.dump only requires a write() method.
        def __init__(self):
            self.size = 0

        def write(self, data):
            n = memoryview(data).nbytes  # data may be bytes, bytearray or a memoryview
            self.size += n
            return n

    counter = _ByteCounter()
    pickle.dump(var, counter)
    print(f'{var_name} requires {humanize.naturalsize(counter.size)}')


def vbyte_encode(number):
    """Encode a non-negative integer with Variable-Byte (VByte) encoding.

    7-bit groups are emitted least-significant first. The *last* byte of the number has its
    high bit (0x80) set as a terminator; vbyte_decode and decode_concatenated_vbyte rely on
    this convention.

    :param number: integer >= 0
    :returns: the encoded ``bytes``
    :raises ValueError: if ``number`` is negative. Right-shifting a negative int never
        reaches 0, so without this check the loop would never terminate.
    """
    if number < 0:
        raise ValueError(f"vbyte_encode requires a non-negative integer, got {number}")
    out = bytearray()
    while True:
        low7 = number & 0x7F  # take the 7 least-significant bits (0x7F = 0111 1111)
        number >>= 7
        if number:
            out.append(low7)  # more groups follow: high bit stays clear
        else:
            out.append(0x80 | low7)  # last group: set the terminator bit (0x80 = 1000 0000)
            return bytes(out)

def vbyte_decode(bytes_seq):
    """Decode the first VByte-encoded integer found in ``bytes_seq``.

    Groups are read least-significant first. Decoding stops at the first byte whose high
    bit is set; if there is no such byte, every byte is consumed.
    """
    value = 0
    shift = 0
    for b in bytes_seq:
        value |= (b & 0x7F) << shift
        if b & 0x80:
            return value
        shift += 7
    return value

def decode_concatenated_vbyte(encoded_bytes):
    """Decode a stream of back-to-back VByte-encoded integers.

    Each number ends at a byte with the high bit (0x80) set. Trailing bytes that never reach
    such a terminator do not produce a number.

    :returns: list of decoded integers, in stream order
    """
    numbers = []
    value, shift = 0, 0
    for b in encoded_bytes:
        value |= (b & 0x7F) << shift
        if b & 0x80:
            # terminator bit: this number is complete
            numbers.append(value)
            value, shift = 0, 0
        else:
            shift += 7
    return numbers

#------------------------------------------------------------------------------------------------------------------------------------------------------------------#

def compress_index(lexicon, inv_d, inv_f):
    """VByte-compress the docid and frequency list of every term in the lexicon.

    Every list is round-tripped through decode_concatenated_vbyte and asserted equal to the
    input before being stored.

    :param lexicon: ``term -> (termid, df, cf)``
    :param inv_d: ``termid -> list of docids``
    :param inv_f: ``termid -> list of term frequencies``
    :returns: ``(compressed_inv_d, compressed_inv_f)``, each mapping termid -> bytearray
    """
    def encode_checked(values):
        # Concatenate the per-number encodings, then verify the round trip.
        buf = bytearray()
        for value in values:
            buf += vbyte_encode(value)
        assert decode_concatenated_vbyte(buf) == values
        return buf

    compressed_inv_d, compressed_inv_f = {}, {}
    for entry in tqdm(lexicon.values(), desc="Compressing lists", unit="term"):
        termid, _, _ = entry
        compressed_inv_d[termid] = encode_checked(inv_d[termid])
        compressed_inv_f[termid] = encode_checked(inv_f[termid])

    return compressed_inv_d, compressed_inv_f

## Functions to build the inverted index

In [25]:
import pandas as pd
from collections import Counter
from tqdm.auto import tqdm

def build_index(filepath, batch_size=10000):
    """Build an in-memory inverted index from a tab-separated ``doc_id``/``text`` file.

    Each non-blank line is one document; ``doc_id`` must parse as an int. Text is tokenised
    with ``preprocess``. Blank lines are skipped. A line with no tab is indexed as an empty
    document.

    :param filepath: path of the collection TSV (read as UTF-8)
    :param batch_size: number of lines between progress-bar refreshes
    :returns: ``(lexicon, inv, doc_index, stats)`` where
        - ``lexicon[term] = [termid, df, cf]``: termids are assigned 0, 1, 2, ... in order of
          first appearance; ``cf`` is the total number of occurrences;
        - ``inv = {'docids': {termid: [doc_id, ...]}, 'freqs': {termid: [tf, ...]}}``, with
          postings in file order;
        - ``doc_index[i] = (str(doc_id), doclen)`` for the i-th document read;
        - ``stats`` holds ``num_docs``, ``num_terms``, ``num_tokens`` and ``avdl``.
    """
    # Count lines up front so the progress bar has a total; the file handle is closed right away.
    with open(filepath, 'r', encoding='utf-8') as file:
        total_documents = sum(1 for _ in file)

    lexicon = {}
    inv_d = {}
    inv_f = {}
    doc_index = []
    total_dl = 0
    num_docs = 0

    with open(filepath, 'r', encoding='utf-8') as file:
        with tqdm(total=total_documents, desc="Processing documents", unit="doc") as pbar:
            pending = 0  # lines read since the last progress-bar refresh
            for raw_line in file:
                line = raw_line.strip()
                if line:
                    # partition (not split + unpack) tolerates a line without a text column
                    doc_id, _, text = line.partition('\t')
                    doc_id = int(doc_id)
                    tokens = preprocess(text)

                    for token, tf in Counter(tokens).items():
                        entry = lexicon.get(token)
                        if entry is None:
                            termid = len(lexicon)  # next free termid
                            entry = lexicon[token] = [termid, 0, 0]  # termid, df, cf
                            inv_d[termid], inv_f[termid] = [], []  # docids, freqs
                        inv_d[entry[0]].append(doc_id)  # add doc_id to the term's posting list
                        inv_f[entry[0]].append(tf)  # term frequency in this doc
                        entry[1] += 1  # document frequency (df)
                        entry[2] += tf  # collection frequency

                    doclen = len(tokens)
                    doc_index.append((str(doc_id), doclen))
                    total_dl += doclen
                    num_docs += 1

                pending += 1
                if pending >= batch_size:
                    pbar.update(pending)
                    pending = 0
            if pending:
                pbar.update(pending)

    # Average document length (avdl), used by BM25
    avdl = total_dl / num_docs if num_docs > 0 else 0

    stats = {
        'num_docs': num_docs,
        'num_terms': len(lexicon),
        'num_tokens': total_dl,
        'avdl': avdl
    }
    return lexicon, {'docids': inv_d, 'freqs': inv_f}, doc_index, stats

In [26]:
import math
import bisect


class InvertedIndex:
    """In-memory inverted index, built from the output of ``build_index``.

    :param lex: lexicon, ``term -> [termid, df, cf]``
    :param inv: ``{'docids': {termid: [docid, ...]}, 'freqs': {termid: [tf, ...]}}``
    :param doc: document index; ``doc[docid][1]`` must be the length of document ``docid``,
        i.e. docids are assumed to be 0-based positions in this list
    :param stats: collection statistics with at least ``num_docs`` and ``avdl``
    """

    class PostingListIterator:
        """Cursor over the posting list of one term, able to score the current posting.

        ``docids`` is assumed sorted ascending (``next(target)`` bisects it).
        """

        def __init__(self, docids, freqs, doc, avdl):
            self.docids = docids
            self.freqs = freqs
            self.pos = 0
            self.doc = doc
            self.total_docs_number = len(doc)  # N, the collection size used by both IDFs
            self.avdl = avdl

        def docid(self):
            """Docid at the cursor, or ``math.inf`` once the list is exhausted."""
            if self.is_end_list():
                return math.inf
            return self.docids[self.pos]

        def score(self, method='tfidf'):
            """Score the current posting with ``'tfidf'`` or ``'bm25'``.

            :raises ValueError: for any other method name
            """
            if method == 'tfidf':
                return self.score_tfidf()
            elif method == 'bm25':
                return self.score_bm25()
            else:
                raise ValueError("Invalid scoring method")

        def score_tfidf(self):
            """TF-IDF of the current posting: ``(1 + ln tf) * ln(N / df)``.

            Returns ``math.inf`` if the list is exhausted; callers should not score past the end.
            """
            if self.is_end_list():
                return math.inf

            tf = self.freqs[self.pos]
            wtd = 1 + math.log(tf) if tf > 0 else 0  # log-scaled tf; 0 avoids log(0)

            df = self.len()  # document frequency = posting-list length
            idf = math.log(self.total_docs_number / df) if df > 0 else 0

            return wtd * idf

        def score_bm25(self, k_1=1.5, b=0.75):
            """BM25 of the current posting: ``tf / (tf + k_1 * B) * ln(N / df)``,
            with ``B = (1 - b) + b * dl / avdl``.

            The ``(k_1 + 1)`` numerator factor of the textbook formula is omitted. It multiplies
            every posting's score by the same constant, so rankings are unaffected.
            Returns ``math.inf`` if the list is exhausted.

            :param k_1: term-frequency saturation parameter
            :param b: document-length normalisation strength
            """
            if self.is_end_list():
                return math.inf

            dl = self.doc[self.docid()][1]  # length of the current document
            tf = self.freqs[self.pos]  # term frequency in the current document
            N = self.total_docs_number  # documents in the collection
            n = self.len()  # documents containing the term

            B_j = (1 - b) + b * (dl / self.avdl)  # document length normalisation
            idf = math.log(N / n)

            return ((tf) / (tf + k_1 * B_j)) * idf

        def next(self, target=None):
            """Advance the cursor.

            With no ``target``, move one posting forward (no-op at the end). With a ``target``,
            skip to the first posting whose docid is ``>= target``; the cursor never moves
            backwards. ``target is None`` is tested explicitly because docid 0 is a valid target.
            """
            if target is None:
                if not self.is_end_list():
                    self.pos += 1
            elif target > self.docid():
                self.pos = bisect.bisect_left(self.docids, target, self.pos)

        def is_end_list(self):
            """True once the cursor has moved past the last posting."""
            return self.pos == len(self.docids)

        def len(self):
            """Length of the posting list, i.e. the term's document frequency."""
            return len(self.docids)

    def __init__(self, lex, inv, doc, stats):
        self.lexicon = lex
        self.inv = inv
        self.doc = doc
        self.stats = stats

    def num_docs(self):
        """Number of documents in the collection."""
        return self.stats['num_docs']

    def avdl(self):
        """Average document length (in tokens)."""
        return self.stats['avdl']

    def get_posting(self, termid):
        """Return a fresh iterator over the posting list of ``termid``."""
        return InvertedIndex.PostingListIterator(self.inv['docids'][termid], self.inv['freqs'][termid], self.doc, self.stats['avdl'])

    def get_termids(self, tokens):
        """Termids of ``tokens`` in order, silently dropping tokens missing from the lexicon."""
        return [self.lexicon[token][0] for token in tokens if token in self.lexicon]

    def get_postings(self, termids):
        """Return one fresh posting iterator per termid."""
        return [self.get_posting(termid) for termid in termids]


In [27]:
# TODO remove?
# import cProfile
# import pstats

# cProfile.run("build_index('./vaswani.tsv')", "output.prof")
# p = pstats.Stats("output.prof")
# p.sort_stats("cumtime").print_stats(10)
# os.remove("output.prof")

## Building the index on the chosen collection 

Now build the index for the chosen collection. For MS MARCO, it is built only if a pickled version (`./pickles/inv_index.pkl`) doesn't already exist; the Vaswani index is rebuilt on every run:

In [28]:
import pickle

# If the 'pickles' directory does not exist, we first create it
os.makedirs('./pickles', exist_ok=True)


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path`` atomically.

    The data is written to ``path + '.tmp'`` and then renamed over ``path`` with
    ``os.replace``. An interrupted dump therefore never leaves a truncated file at ``path``;
    the loader below would otherwise fail to unpickle it instead of rebuilding the index.
    """
    tmp_path = path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(obj, f)
    os.replace(tmp_path, path)


collection_path = './collection/' + chosen_collection + '/' + chosen_collection + '.tsv'

if chosen_collection == "msmarco":
    try:  # reuse the cached index if it exists, else build it
        # NOTE(security): pickle.load can execute arbitrary code; only load pickles written by this notebook.
        with open('./pickles/inv_index.pkl', 'rb') as f:
            inv_index = pickle.load(f)

    except FileNotFoundError:
        lex, inv, doc, stats = build_index(collection_path)

        # Save the lexicon, inverted lists, document index and statistics to disk
        _dump_pickle(lex, './pickles/lex.pkl')
        _dump_pickle(inv, './pickles/inv.pkl')
        _dump_pickle(doc, './pickles/doc.pkl')
        _dump_pickle(stats, './pickles/stats.pkl')

        # Compress the inverted lists
        #inv['docids'], inv['freqs'] = compress_index(lex, inv['docids'], inv['freqs'])

        inv_index = InvertedIndex(lex, inv, doc, stats)
        _dump_pickle(inv_index, './pickles/inv_index.pkl')
else:
    # The Vaswani index is not cached: it is rebuilt on every run.
    lex, inv, doc, stats = build_index(collection_path)
    inv_index = InvertedIndex(lex, inv, doc, stats)


print(f"Numero di documenti: {inv_index.num_docs()}")

Numero di documenti: 8841823


In [29]:
# Average document length (tokens per document after preprocessing), from the index statistics
print("Avdl: " + str(inv_index.avdl()))

Avdl: 34.687022122021666


In [30]:
#print_pickled_size('inv_index', inv_index)

# 2. Download and prepare queries

In [31]:
import gzip
import os

if chosen_collection not in ["vaswani", "msmarco"]:
    raise ValueError("chosen_collection must be one of ['vaswani', 'msmarco']")

if chosen_collection == "msmarco":
    # TREC DL 2019 test queries: downloaded and decompressed only once.
    if not os.path.exists('./collection/msmarco/msmarco-queries.tsv'):
        url = 'https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz'
        gdown.download(url, './collection/msmarco/msmarco-test2019-queries.tsv.gz', quiet=False)
        # Explicit UTF-8 on both sides, so the copy does not depend on the platform's default encoding.
        with gzip.open('./collection/msmarco/msmarco-test2019-queries.tsv.gz', 'rt', encoding='utf-8') as f_in:
            with open('./collection/msmarco/msmarco-queries.tsv', 'w', encoding='utf-8') as f_out:
                f_out.write(f_in.read())
        os.remove('./collection/msmarco/msmarco-test2019-queries.tsv.gz') # delete the compressed file
    queries = pd.read_csv('./collection/msmarco/msmarco-queries.tsv', sep='\t', header=None)
    queries.columns = ['qid', 'query']
    print("Number of queries: ",len(queries))

    # TREC DL 2019 passage qrels, space-separated: qid, Q0, docid, rating
    if not os.path.exists('./collection/msmarco/msmarco-qrels.txt'):
        url = 'https://trec.nist.gov/data/deep/2019qrels-pass.txt'
        gdown.download(url, './collection/msmarco/msmarco-qrels.txt', quiet=False)
    qrels = pd.read_csv('./collection/msmarco/msmarco-qrels.txt', sep=' ', header=None)
    qrels.columns = ['qid', 'Q0', 'docid', 'rating']
    print("Number of relevance judgments: ",len(qrels))


elif chosen_collection == "vaswani":
    # Queries and qrels come from the ir_datasets object loaded in the collection cell.
    queries = pd.DataFrame(vaswani_dataset.queries_iter())
    queries.columns = ['qid', 'query']
    print("Number of queries: ",len(queries))
    if not os.path.exists('./collection/vaswani/vaswani-queries.tsv'):
        queries.to_csv('./collection/vaswani/vaswani-queries.tsv', sep='\t', header=False, index=False)
    qrels = pd.DataFrame(vaswani_dataset.qrels_iter())
    qrels.columns = ['qid', 'docid', 'relevance', 'iteration']
    qrels['docid'] = (qrels['docid'].astype(int) - 1).astype(str) # convert to 0-based indexing

    if not os.path.exists('./collection/vaswani/vaswani-qrels.txt'):
        qrels.to_csv('./collection/vaswani/vaswani-qrels.txt', sep='\t', header=False, index=False)
    print("Number of relevance judgments: ",len(qrels))

Number of queries:  200
Number of relevance judgments:  9260


In [32]:
from collections import namedtuple


Query = namedtuple('Query', ['query_id', 'text'])


class QueriesDataset:
    """Minimal, ``ir_datasets``-like queries container built from a DataFrame.

    :param df: DataFrame with ``query_id`` and ``text`` columns
    """

    def __init__(self, df):
        self.queries = [Query(row.query_id, row.text) for row in df.itertuples()]
        # query_id -> Query. Ids are arbitrary (e.g. TREC qids such as 1108939), so
        # get_query must look them up by id rather than by list position.
        self._by_id = {q.query_id: q for q in self.queries}

    def queries_iter(self):
        """Iterate over the queries in DataFrame order."""
        return iter(self.queries)

    def queries_count(self):
        """Number of queries."""
        return len(self.queries)

    def get_query(self, query_id):
        """Return the Query whose id equals ``query_id``.

        :raises KeyError: if no query has that id
        """
        return self._by_id[query_id]
# Rename the columns to the Query namedtuple's fields, then wrap the frame.
queries.columns = list(Query._fields)
queries_dataset = QueriesDataset(queries)
n_queries = queries_dataset.queries_count()
print("The number of queries is: ", n_queries)

The number of queries is:  200


Let's prepare the functions needed to perform TAAT (term-at-a-time) and DAAT (document-at-a-time) query processing.

First, we need a `TopQueue` class, which stores the top $K$ `(score, docid)` tuples in a min-heap.

In [33]:
import heapq

class TopQueue:
    """Keeps the ``k`` best ``(score, docid)`` pairs seen so far in a min-heap.

    ``threshold`` is the score a new entry must strictly exceed to be admitted. It starts at
    the constructor value; once the heap holds ``k`` items it follows the smallest score in
    the heap, and it never decreases during insertion.
    """

    def __init__(self, k=10, threshold=0.0):
        self.queue = []  # heap of (score, docid); queue[0] is the current minimum
        self.k = k
        self.threshold = threshold

    def size(self):
        """Number of entries currently held."""
        return len(self.queue)

    def would_enter(self, score):
        """True if ``score`` is high enough for ``insert`` to accept it."""
        return score > self.threshold

    def clear(self, new_threshold=None):
        """Empty the queue and optionally reset the admission threshold.

        Uses ``is not None`` rather than truthiness, so that ``new_threshold=0.0`` is honoured.
        """
        self.queue = []
        if new_threshold is not None:
            self.threshold = new_threshold

    def __repr__(self):
        return f'<{self.size()} items, th={self.threshold} {self.queue}>'

    def insert(self, docid, score):
        """Try to add ``(score, docid)``; return True if it was admitted.

        When the queue is already full, the current minimum is evicted. With ``k <= 0``
        nothing is ever admitted.
        """
        if self.k <= 0 or not self.would_enter(score):
            return False
        if self.size() >= self.k:
            heapq.heapreplace(self.queue, (score, docid))
        else:
            heapq.heappush(self.queue, (score, docid))
        if self.size() >= self.k:
            self.threshold = max(self.threshold, self.queue[0][0])
        return True

#print(sorted(topq.queue, reverse=True)) # print the queue sorted by score

### TAAT

In [34]:
from collections import defaultdict

def taat(postings, k=10, method='bm25'):
    """Term-at-a-time query processing.

    Each posting list is scanned to the end in turn, adding ``posting.score(method)`` to a
    per-document accumulator. The accumulated scores then go through a ``TopQueue`` of size ``k``.

    :returns: up to ``k`` ``(score, docid)`` tuples, highest first
    """
    scores = defaultdict(float)
    for plist in postings:
        # iter(callable, sentinel) keeps calling plist.docid() until it returns math.inf
        for docid in iter(plist.docid, math.inf):
            scores[docid] += plist.score(method)
            plist.next()

    best = TopQueue(k)
    for docid, total in scores.items():
        best.insert(docid, total)
    return sorted(best.queue, reverse=True)


def query_process(query, index, k=10, method='bm25'):
    """Answer ``query`` against ``index`` using term-at-a-time (``taat``) processing.

    Query terms are deduplicated, and terms absent from the lexicon are ignored.
    NOTE: the DAAT cell below redefines ``query_process``; whichever cell ran last wins.

    :param k: number of results to return
    :param method: scoring method, ``'tfidf'`` or ``'bm25'``
    :returns: up to ``k`` ``(score, docid)`` tuples, highest first
    """
    qtokens = set(preprocess(query))
    qtermids = index.get_termids(qtokens)
    postings = index.get_postings(qtermids)
    return taat(postings, k=k, method=method)

### DAAT

In [35]:
import math

def min_docid(postings):
    """Smallest current docid among the non-exhausted posting lists (``math.inf`` if none)."""
    return min(
        (plist.docid() for plist in postings if not plist.is_end_list()),
        default=math.inf,
    )

def daat(postings, k=10, method='bm25'):
    """Document-at-a-time query processing.

    All posting lists are walked in parallel in increasing docid order. A document's score is
    the sum of ``posting.score(method)`` over every list that contains it, so each document
    is scored exactly once.

    :param postings: posting iterators exposing ``docid()``, ``next()``, ``score()``, ``is_end_list()``
    :param k: number of results to keep
    :param method: scoring method forwarded to ``posting.score``
    :returns: up to ``k`` ``(score, docid)`` tuples, highest first
    """
    top = TopQueue(k)
    current_docid = min_docid(postings)
    while current_docid != math.inf:
        score = 0
        next_docid = math.inf
        for posting in postings:
            if posting.docid() == current_docid:
                score += posting.score(method)
                posting.next()
            # The next document to visit is the smallest pending docid across ALL lists.
            # Taking only the last list's docid would skip documents and split their scores.
            if not posting.is_end_list():
                next_docid = min(next_docid, posting.docid())
        top.insert(current_docid, score)
        current_docid = next_docid
    return sorted(top.queue, reverse=True)

def query_process(query, index, k=10, method='bm25'):
    """Answer ``query`` against ``index`` using document-at-a-time (``daat``) processing.

    Query terms are deduplicated, and terms absent from the lexicon are ignored.
    NOTE: this redefines the TAAT-based ``query_process`` from the previous cell.

    :param k: number of results to return
    :param method: scoring method, ``'tfidf'`` or ``'bm25'``
    :returns: up to ``k`` ``(score, docid)`` tuples, highest first
    """
    qtokens = set(preprocess(query))
    qtermids = index.get_termids(qtokens)
    postings = index.get_postings(qtermids)
    return daat(postings, k=k, method=method)

In [None]:
from tqdm import tqdm
import cProfile
import pstats

@profile
def query_processing(queries_iter, fn):
    """Run ``fn`` on the postings of every query and discard the results.

    Used to time and profile a query-processing strategy (e.g. ``taat`` or ``daat``) with its
    default ``k`` and ``method``, over the global ``inv_index``.
    """
    n_queries = queries_dataset.queries_count()
    for query in tqdm(queries_iter, desc="Processing queries", total=n_queries, unit="query"):
        tokens = preprocess(query.text)
        postings = inv_index.get_postings(inv_index.get_termids(tokens))
        fn(postings)


# cProfile.run writes the stats file but does not create missing directories.
os.makedirs("./perfm", exist_ok=True)
cProfile.run("query_processing(queries_dataset.queries_iter(), taat)", "./perfm/result.prof")
p = pstats.Stats("./perfm/result.prof")
# Show the 25 most expensive calls by cumulative time; ';' hides the Stats object's repr.
p.sort_stats("cumtime").print_stats(25);

---

# 3. Evaluation

A relevance assessment (called ***qrel*** in `ir_datasets`) is composed of:
* a **topic id** (called *query_id* in `ir_datasets`) as in a topic,
* a **docno** (called *doc_id* in `ir_datasets`) as in a document,
* a **judgement** (called *relevance* in `ir_datasets`) as a binary or graded relevance judgment/label, and
* an **iteration**, **UNUSED** and always equal to the string `'0'`.

In [None]:
# Re-read the qrels saved on disk for the chosen collection. The Vaswani file was written
# tab-separated by this notebook; the TREC DL 2019 file is space-separated.
sep = '\t' if chosen_collection == "vaswani" else ' '

qrels = pd.read_csv('./collection/'+chosen_collection+'/'+chosen_collection+'-qrels.txt', sep=sep, header=None)

if chosen_collection == "vaswani":
    qrels.columns = ['query_id', 'doc_id', 'relevance', 'iteration']
else:
    qrels.columns = ['query_id', 'Q0', 'doc_id', 'relevance']

# read_csv parses the numeric ids as integers, while the run files are read back with string
# ids. Normalise the dtypes for BOTH collections (not only MS MARCO) so query and document
# ids match.
qrels = qrels.astype({'query_id': str, 'doc_id': str, 'relevance': int})

print("Number of relevance judgments: ",len(qrels))

Number of relevance judgments:  9260


In [41]:
import os

def create_run_file(queries_iter, fn, k, method, run_id, output_file, index=None):
    """
    Run every query through ``fn`` and write the ranked results as a TREC run file.

    Each line is tab-separated: ``topic_id Q0 docno rank score run_id`` (rank starts at 1).
    A query with no results gets one placeholder line with docno ``NO_RESULTS``.
    Query terms are not deduplicated here (unlike ``query_process``), so a term repeated in
    a query contributes once per occurrence.

    :param queries_iter: iterator of Query(query_id, text)
    :param fn: query-processing function (e.g. ``taat`` or ``daat``), called as
        ``fn(postings, k=k, method=method)``; returns ``(score, docid)`` tuples, best first
    :param k: number of results to keep per query
    :param method: scoring method forwarded to ``fn`` (``'tfidf'`` or ``'bm25'``)
    :param run_id: name identifier for the run (last column)
    :param output_file: run file name, written inside ./results
    :param index: InvertedIndex to query; defaults to the global ``inv_index``
    """
    if index is None:
        index = inv_index
    os.makedirs('./results', exist_ok=True)
    with open(f"./results/{output_file}", "w", encoding="utf-8") as f:
        for q in queries_iter:
            topic_id = q.query_id
            query = preprocess(q.text)
            termids = index.get_termids(query)
            postings = index.get_postings(termids)
            results = fn(postings, k=k, method=method)

            if results:
                # Write results to the run file
                for rank, (score, docno) in enumerate(results, start=1):
                    line = f"{topic_id}\tQ0\t{docno}\t{rank}\t{score:.6f}\t{run_id}\n"
                    f.write(line)
            else:
                # Annotate that no results were found for this query
                line = f"{topic_id}\tQ0\tNO_RESULTS\t0\t0.0\t{run_id}\n"
                f.write(line)

    print(f"Run file {output_file} produced successfully.")

In [72]:
# Produce each run file unless that file already exists. Checking the files themselves rather
# than the ./results directory means a missing run (e.g. after switching chosen_collection)
# is still generated; create_run_file creates the directory when needed.
for method in ('tfidf', 'bm25'):
    run_name = f"{chosen_collection}_{method}_200_queries.run"
    if not os.path.exists(os.path.join('./results', run_name)):
        create_run_file(queries_dataset.queries_iter(), taat, 1000, method, f"run_{method}", run_name)

In [44]:
num_docs = inv_index.num_docs()
print(num_docs)

8841823


# IR Measures

In [None]:
import ir_measures
from ir_measures import *

tfidf_run_path = f'./results/{chosen_collection}_tfidf_200_queries.run'
bm25_run_path = f'./results/{chosen_collection}_bm25_200_queries.run'
# read_trec_run yields lazily; materialise both runs so each can be evaluated many times.
run_file_tfidf = list(ir_measures.read_trec_run(tfidf_run_path))
run_file_bm25 = list(ir_measures.read_trec_run(bm25_run_path))


In [47]:
# [Mean] Average Precision ([M]AP): for a single query, the mean of the precision values
# measured at the rank of each relevant document returned.
# Variants: all documents, judged documents only, and relevance thresholds of 2 and 3.
# Each variant is printed for TF-IDF first, then BM25.
ap_measures = (AP, AP(judged_only=True), AP(rel=2, judged_only=True), AP(rel=3, judged_only=True))
for measure in ap_measures:
    for label, run in (("TFIDF: ", run_file_tfidf), ("BM25: ", run_file_bm25)):
        print(label, ir_measures.calc_aggregate([measure], qrels, run))

TFIDF:  {AP: 0.26743796793877844}
BM25:  {AP: 0.36512401777983017}
TFIDF:  {AP(judged_only=True): 0.4323169834401592}
BM25:  {AP(judged_only=True): 0.4826520578883197}
TFIDF:  {AP(rel=2,judged_only=True): 0.3196569021180839}
BM25:  {AP(rel=2,judged_only=True): 0.34134864158104516}
TFIDF:  {AP(rel=3,judged_only=True): 0.1657749231625388}
BM25:  {AP(rel=3,judged_only=True): 0.18336371701465518}


In [48]:
# normalized Discounted Cumulative Gain (nDCG): uses graded labels, rewarding systems that put
# the highest-graded documents at the top; normalised w.r.t. the ideal DCG (documents sorted
# by descending label). Full ranking and cutoff 10, over all / judged-only documents.
ndcg_measures = (nDCG, nDCG@10, nDCG(judged_only=True), nDCG(judged_only=True)@10)
for measure in ndcg_measures:
    for label, run in (("TFIDF: ", run_file_tfidf), ("BM25: ", run_file_bm25)):
        print(label, ir_measures.calc_aggregate([measure], qrels, run))

TFIDF:  {nDCG: 0.5164447977209103}
BM25:  {nDCG: 0.5939273906659651}
TFIDF:  {nDCG@10: 0.41352178810945667}
BM25:  {nDCG@10: 0.4727669870653633}
TFIDF:  {nDCG(judged_only=True): 0.5896083919582058}
BM25:  {nDCG(judged_only=True): 0.6386108191944135}
TFIDF:  {nDCG(judged_only=True)@10: 0.48249339364752636}
BM25:  {nDCG(judged_only=True)@10: 0.47913114994269124}


In [49]:
# Binary Preference (Bpref): examines the relative ranks of judged relevant and judged
# non-relevant documents; non-judged documents are not considered.
bpref_measures = (Bpref, Bpref(rel=2), Bpref(rel=3))
for measure in bpref_measures:
    for label, run in (("TFIDF: ", run_file_tfidf), ("BM25: ", run_file_bm25)):
        print(label, ir_measures.calc_aggregate([measure], qrels, run))

TFIDF:  {Bpref: 0.4571673693044115}
BM25:  {Bpref: 0.4909895312262259}
TFIDF:  {Bpref(rel=2): 0.30611709329769105}
BM25:  {Bpref(rel=2): 0.31312768729137774}
TFIDF:  {Bpref(rel=3): 0.12365750849455547}
BM25:  {Bpref(rel=3): 0.1311730279512321}


In [50]:
# Judged: percentage of results in the top k (cutoff) results that have relevance judgments.
# Equivalent to P@k with a rel lower than any judgment.
judged_measures = (Judged, Judged@5, Judged@10, Judged@100)
for measure in judged_measures:
    for label, run in (("TFIDF: ", run_file_tfidf), ("BM25: ", run_file_bm25)):
        print(label, ir_measures.calc_aggregate([measure], qrels, run))

TFIDF:  {Judged: 0.10595348837209306}
BM25:  {Judged: 0.12069767441860466}
TFIDF:  {Judged@5: 0.8372093023255816}
BM25:  {Judged@5: 0.9813953488372092}
TFIDF:  {Judged@10: 0.7581395348837208}
BM25:  {Judged@10: 0.9488372093023255}
TFIDF:  {Judged@100: 0.41418604651162794}
BM25:  {Judged@100: 0.5362790697674419}


In [51]:
# TODO: review this cell and decide whether it belongs here.
# Everything below is a single string literal (disabled code), so the cell only displays the
# string. If re-enabled as-is it would fail: it uses a `run_file` variable that no cell
# defines (the runs are loaded as run_file_tfidf / run_file_bm25).

'''count = 0
for metric in ir_measures.iter_calc([P@5, P(rel=2)@5, nDCG@10, AP, AP(rel=2), Bpref, Bpref(rel=2), Judged@10], qrels, run_file):
  print(metric)
  count += 1
  if count >= 10: break # only show top 10 items

run_file_p_rel2_5 = {m.query_id: m.value for m in ir_measures.iter_calc([Bpref(rel=2)], qrels, run_file)}

from scipy.stats import ttest_rel
qids = list(run_file_p_rel2_5.keys())
#ttest_rel([run_file_p_rel2_5[v] for v in qids])'''


'count = 0\nfor metric in ir_measures.iter_calc([P@5, P(rel=2)@5, nDCG@10, AP, AP(rel=2), Bpref, Bpref(rel=2), Judged@10], qrels, run_file):\n  print(metric)\n  count += 1\n  if count >= 10: break # only show top 10 items\n\nrun_file_p_rel2_5 = {m.query_id: m.value for m in ir_measures.iter_calc([Bpref(rel=2)], qrels, run_file)}\n\nfrom scipy.stats import ttest_rel\nqids = list(run_file_p_rel2_5.keys())\n#ttest_rel([run_file_p_rel2_5[v] for v in qids])'

---

# PyTerrier

In [None]:
import pyterrier as pt
#import pandas as pd
#import os

# Start the JVM used by PyTerrier's Terrier backend, but only if it is not already
# running. pt.started() / pt.init() are deprecated; this uses the pt.java.* API.
# The UTF-8 encoding option is registered before pt.java.init(). If the JVM was
# already started (e.g. on a re-run), this branch is skipped and the option is not
# re-applied, so restart the kernel to change it.
if not pt.java.started():
    pt.java.add_option('-Dtrec.encoding=UTF-8')
    pt.java.init()

In [54]:
# Handle to the MS MARCO passage dataset, then a peek at the first rows of the
# TREC DL 2019 ("test-2019") topics and of its relevance judgments, in that order.
dataset = pt.get_dataset("msmarco_passage")

for load_split in (dataset.get_topics, dataset.get_qrels):
    print(load_split("test-2019").head())

       qid                                              query
0  1108939                  what slows down the flow of blood
1  1112389             what is the county for grand rapids mn
2   792752                                     what is ruclip
3  1119729  what do you do when you have a nosebleed from ...
4  1105095                  where is sugar lake lodge located
     qid    docno  label
0  19335  1017759      0
1  19335  1082489      0
2  19335   109063      0
3  19335  1160863      0
4  19335  1160871      0


In [55]:
# Open the prebuilt, stemmed Terrier index of the MS MARCO passage corpus
# (obtained through the PyTerrier dataset API) and print its collection
# statistics: documents, terms, postings, tokens, fields.
stemmed_index_ref = dataset.get_index(variant='terrier_stemmed')
index = pt.IndexFactory.of(stemmed_index_ref)
collection_stats = index.getCollectionStatistics()
print(collection_stats)

Number of documents: 8841823
Number of terms: 1170682
Number of postings: 215238456
Number of fields: 1
Number of tokens: 288759529
Field names: [text]
Positions:   false



In [76]:
# Size of the TREC DL 2019 test split:
#   1. number of topics (queries),
#   2. number of relevance-judgment rows,
#   3. number of distinct topics that actually have judgments.
# The recorded output shows far fewer judged topics than topics overall.
# The qrels are loaded once and reused; previously they were loaded and parsed twice.
topics_2019 = dataset.get_topics("test-2019")
qrels_2019 = dataset.get_qrels("test-2019")

print(len(topics_2019))
print(len(qrels_2019))
print(len(qrels_2019["qid"].unique()))

200
9260
43


In [56]:
# Old approach (disabled): loaded the MS MARCO TSV collection into a DataFrame
# with string `docno` / `text` columns. Presumably superseded by the prebuilt
# `terrier_stemmed` index opened via dataset.get_index.
# Kept as a string literal, so nothing runs; Jupyter echoes it as the cell output.
# Note: the first comment inside the string is inaccurate. The code reads the
# TSV collection, not a pickled inverted index.

'''
# Load the pickled inverted index
df = pd.read_csv("collection/msmarco/msmarco.tsv", sep="\t", header=None)

# assign columns
df.columns = ["docno", "text"]

# Convert columns to strings
df["docno"] = df["docno"].astype(str)
df["text"] = df["text"].astype(str)


# Display the DataFrame to verify
print(df.head())
'''

'\n# Load the pickled inverted index\ndf = pd.read_csv("collection/msmarco/msmarco.tsv", sep="\t", header=None)\n\n# assign columns\ndf.columns = ["docno", "text"]\n\n# Convert columns to strings\ndf["docno"] = df["docno"].astype(str)\ndf["text"] = df["text"].astype(str)\n\n\n# Display the DataFrame to verify\nprint(df.head())\n'

In [57]:
# Old approach (disabled): built a local Terrier index in ./index_3docs from the
# DataFrame of the previous disabled cell, using pt.IterDictIndexer. Indexing ran
# only when the folder was missing or empty.
# Kept as a string literal, so nothing runs; Jupyter echoes it as the cell output.

'''

# Convert DataFrame to a list of dictionaries
docs = [{"docno": docno, "text": text} for docno, text in zip(df["docno"], df["text"])]

# get the root path of the current directory
cwd = os.getcwd()
# from this get the path to index_3docs
index_path = os.path.join(cwd, "index_3docs")

# if the folder does not exists or is empty
if not os.path.exists(index_path) or len(os.listdir(index_path)) == 0:
    # indexer = pt.DFIndexer("./index_3docs", overwrite=True) # deprecated
    indexer = pt.IterDictIndexer(index_path, overwrite=True)
    indexref = indexer.index(docs)
    indexref.toString()
    
'''

'\n\n# Convert DataFrame to a list of dictionaries\ndocs = [{"docno": docno, "text": text} for docno, text in zip(df["docno"], df["text"])]\n\n# get the root path of the current directory\ncwd = os.getcwd()\n# from this get the path to index_3docs\nindex_path = os.path.join(cwd, "index_3docs")\n\n# if the folder does not exists or is empty\nif not os.path.exists(index_path) or len(os.listdir(index_path)) == 0:\n    # indexer = pt.DFIndexer("./index_3docs", overwrite=True) # deprecated\n    indexer = pt.IterDictIndexer(index_path, overwrite=True)\n    indexref = indexer.index(docs)\n    indexref.toString()\n    \n'

In [58]:
# Disabled: opened the locally built ./index_3docs index and printed its
# statistics. Kept as a string literal, so nothing runs. If re-enabled, it would
# overwrite the global `index`.
'''
indexref = pt.IndexRef.of(os.path.abspath("index_3docs/data.properties"))
index = pt.IndexFactory.of(indexref)
print(index.getCollectionStatistics())
'''

'\nindexref = pt.IndexRef.of(os.path.abspath("index_3docs/data.properties"))\nindex = pt.IndexFactory.of(indexref)\nprint(index.getCollectionStatistics())\n'

In [59]:
# Old approach (disabled): read msmarco-queries.tsv, applied the notebook's
# `preprocess` function to each query, and round-tripped the result through
# processed_queries.txt to get a (qid, query) DataFrame.
# Kept as a string literal, so nothing runs; Jupyter echoes it as the cell output.
# The `re` / `string` imports inside the string are unused by that code.

'''
import pandas as pd
import re
import string


# File path for the input file
file_path = "collection/msmarco/msmarco-queries.tsv"  # Replace with your file path
# Read the input file
queries = pd.read_csv(file_path, sep="\t", header=None, names=["qid", "query"])

# Apply preprocessing to each query
queries["processed_query"] = queries["query"].apply(preprocess)

# Save the processed queries to a new file
output_path = "processed_queries.txt"
queries[["qid", "processed_query"]].to_csv(output_path, sep="\t", index=False)

# Read the processed file with the correct column names
processed_queries = pd.read_csv(output_path, sep="\t", names=["qid", "query"], skiprows=1)
'''

'\nimport pandas as pd\nimport re\nimport string\n\n\n# File path for the input file\nfile_path = "collection/msmarco/msmarco-queries.tsv"  # Replace with your file path\n# Read the input file\nqueries = pd.read_csv(file_path, sep="\t", header=None, names=["qid", "query"])\n\n# Apply preprocessing to each query\nqueries["processed_query"] = queries["query"].apply(preprocess)\n\n# Save the processed queries to a new file\noutput_path = "processed_queries.txt"\nqueries[["qid", "processed_query"]].to_csv(output_path, sep="\t", index=False)\n\n# Read the processed file with the correct column names\nprocessed_queries = pd.read_csv(output_path, sep="\t", names=["qid", "query"], skiprows=1)\n'

In [60]:
# Old approach (disabled). The Italian comments inside the string translate as follows.
# `concatenate_list` parses a stringified term list with ast.literal_eval and
# joins it with spaces, returning the input unchanged if it is not a valid list
# literal. The steps were:
#   1. read the original queries,
#   2. apply `preprocess`,
#   3. save to a temporary file,
#   4. re-read it and join the term lists into query strings,
#   5. save final_processed_queries.txt.
# Then it loaded the collection's qrels with string qid / docno columns.
# Kept as a string literal, so nothing runs; Jupyter echoes it as the cell output.
'''
import ast

def concatenate_list(query):
    """
    Funzione per concatenare gli elementi di una lista di termini in una stringa.
    """
    try:
        query_list = ast.literal_eval(query)  # Converte la stringa in lista
        return " ".join(query_list)  # Concatena gli elementi della lista con uno spazio
    except (ValueError, SyntaxError):
        return query  # Restituisci il valore originale se non è una lista valida


# Step 1: Leggi il file delle query originali
file_path = "collection/msmarco/msmarco-queries.tsv"  # Sostituisci con il percorso del file
queries = pd.read_csv(file_path, sep="\t", header=None, names=["qid", "query"])

# Step 2: Applica la funzione preprocess a tutte le query
queries["processed_query"] = queries["query"].apply(preprocess)

# Step 3: Salva le query preprocessate in un file temporaneo
temp_output_path = "processed_queries_temp.txt"
queries[["qid", "processed_query"]].to_csv(temp_output_path, sep="\t", index=False)

# Step 4: Rileggi il file temporaneo per applicare concatenate_list
processed_queries = pd.read_csv(temp_output_path, sep="\t")

# Applica concatenate_list alla colonna 'processed_query'
processed_queries["qid"] = processed_queries["qid"].astype(str)
processed_queries["query"] = processed_queries["processed_query"].apply(concatenate_list)
processed_queries = processed_queries[["qid", "query"]]
# Step 5: Salva il file finale con le colonne 'qid' e 'query'
final_output_path = "final_processed_queries.txt"
processed_queries[["qid", "query"]].to_csv(final_output_path, sep="\t", index=False)

# Output per conferma
print(f"File processato salvato in: {final_output_path}")
print(processed_queries.head())

processed_queries["qid"] = processed_queries["qid"].astype(str)


qrels = pd.read_csv('./collection/'+chosen_collection+'/'+chosen_collection+'-qrels.txt', sep=' ', header=None)
qrels.columns = ['qid', 'Q0', 'docno', 'relevance']
qrels['qid'] = qrels['qid'].apply(str)
qrels['docno'] = qrels['docno'].apply(str)
#qrels['relevance'] = qrels['relevance'].apply(int)
'''

'\nimport ast\n\ndef concatenate_list(query):\n    """\n    Funzione per concatenare gli elementi di una lista di termini in una stringa.\n    """\n    try:\n        query_list = ast.literal_eval(query)  # Converte la stringa in lista\n        return " ".join(query_list)  # Concatena gli elementi della lista con uno spazio\n    except (ValueError, SyntaxError):\n        return query  # Restituisci il valore originale se non è una lista valida\n\n\n# Step 1: Leggi il file delle query originali\nfile_path = "collection/msmarco/msmarco-queries.tsv"  # Sostituisci con il percorso del file\nqueries = pd.read_csv(file_path, sep="\t", header=None, names=["qid", "query"])\n\n# Step 2: Applica la funzione preprocess a tutte le query\nqueries["processed_query"] = queries["query"].apply(preprocess)\n\n# Step 3: Salva le query preprocessate in un file temporaneo\ntemp_output_path = "processed_queries_temp.txt"\nqueries[["qid", "processed_query"]].to_csv(temp_output_path, sep="\t", index=False)\n\n

In [61]:
# Old approach (disabled). TODO: rewrite using the current topics/qrels objects.
# Compared the set of qids in `processed_queries` with those in `qrels` and
# printed the total, common and missing counts. The Italian messages inside the
# string read: "total qids in queries/qrels", "common qids", "qids missing from
# qrels/queries".
# Kept as a string literal, so nothing runs; Jupyter echoes it as the cell output.
'''
queries_qids = set(processed_queries["qid"])
qrels_qids = set(qrels["qid"])

common_qids = queries_qids & qrels_qids
missing_in_qrels = queries_qids - qrels_qids
missing_in_queries = qrels_qids - queries_qids

# Mostra i risultati
print(f"Numero totale di qid in queries: {len(queries_qids)}")
print(f"Numero totale di qid in qrels: {len(qrels_qids)}")
print(f"Numero di qid comuni: {len(common_qids)}")
print(f"Qid mancanti nei qrels: {len(missing_in_qrels)}")
print(f"Qid mancanti nelle queries: {len(missing_in_queries)}")
'''

'\nqueries_qids = set(processed_queries["qid"])\nqrels_qids = set(qrels["qid"])\n\ncommon_qids = queries_qids & qrels_qids\nmissing_in_qrels = queries_qids - qrels_qids\nmissing_in_queries = qrels_qids - queries_qids\n\n# Mostra i risultati\nprint(f"Numero totale di qid in queries: {len(queries_qids)}")\nprint(f"Numero totale di qid in qrels: {len(qrels_qids)}")\nprint(f"Numero di qid comuni: {len(common_qids)}")\nprint(f"Qid mancanti nei qrels: {len(missing_in_qrels)}")\nprint(f"Qid mancanti nelle queries: {len(missing_in_queries)}")\n'

In [62]:
# Disabled (kept as a string literal, never executed). It wrote only the qrels
# rows with non-zero relevance to filtered_msmarco-qrels.txt, then reloaded that
# file and printed how many distinct qids it contains. The Italian comments
# inside the string describe exactly these steps.
'''import pandas as pd

# Percorso del file da leggere
input_file = "collection/msmarco/msmarco-qrels.txt"  # Sostituisci con il percorso del tuo file
output_file = "filtered_msmarco-qrels.txt"  # Nome del file di output

# Leggi il file con pandas
df = pd.read_csv(input_file, sep=" ", header=None, names=["qid", "Q0", "docno", "relevance"])

# Filtra le righe in cui la colonna 'relevance' è diversa da zero
filtered_df = df[df["relevance"] != 0]

# Salva il risultato in un nuovo file
filtered_df.to_csv(output_file, sep=" ", index=False, header=False)

print(f"File salvato in: {output_file}")

file_path = "filtered_msmarco-qrels.txt"

# Leggi il file con pandas
df = pd.read_csv(file_path, sep=" ", header=None, names=["qid", "Q0", "docno", "relevance"])

# Conta il numero di valori univoci nella colonna 'qid'
unique_values_count = df["qid"].nunique()

# Stampa il numero di valori univoci
print(f"Numero di valori univoci nella prima colonna ('qid'): {unique_values_count}")
'''

# Disabled duplicate of the earlier qid-overlap check (same code), kept as a
# string literal. As the cell's last expression it is echoed as the cell output.
"""queries_qids = set(processed_queries["qid"])
qrels_qids = set(qrels["qid"])

common_qids = queries_qids & qrels_qids
missing_in_qrels = queries_qids - qrels_qids
missing_in_queries = qrels_qids - queries_qids

# Mostra i risultati
print(f"Numero totale di qid in queries: {len(queries_qids)}")
print(f"Numero totale di qid in qrels: {len(qrels_qids)}")
print(f"Numero di qid comuni: {len(common_qids)}")
print(f"Qid mancanti nei qrels: {len(missing_in_qrels)}")
print(f"Qid mancanti nelle queries: {len(missing_in_queries)}")"""

'queries_qids = set(processed_queries["qid"])\nqrels_qids = set(qrels["qid"])\n\ncommon_qids = queries_qids & qrels_qids\nmissing_in_qrels = queries_qids - qrels_qids\nmissing_in_queries = qrels_qids - queries_qids\n\n# Mostra i risultati\nprint(f"Numero totale di qid in queries: {len(queries_qids)}")\nprint(f"Numero totale di qid in qrels: {len(qrels_qids)}")\nprint(f"Numero di qid comuni: {len(common_qids)}")\nprint(f"Qid mancanti nei qrels: {len(missing_in_qrels)}")\nprint(f"Qid mancanti nelle queries: {len(missing_in_queries)}")'

In [63]:
# Disabled (string literal, echoed as the cell output): counted the distinct qids
# in the unfiltered msmarco-qrels.txt. The Italian comments say "count the
# distinct values in the 'qid' column" and "print that number".
'''file_path = "collection/msmarco/msmarco-qrels.txt"

df = pd.read_csv(file_path, sep=" ", header=None, names=["qid", "Q0", "docno", "relevance"])

# Conta il numero di valori univoci nella colonna 'qid'
unique_values_count = df["qid"].nunique()

# Stampa il numero di valori univoci
print(f"Numero di valori univoci nella prima colonna ('qid'): {unique_values_count}")'''

'file_path = "collection/msmarco/msmarco-qrels.txt"\n\ndf = pd.read_csv(file_path, sep=" ", header=None, names=["qid", "Q0", "docno", "relevance"])\n\n# Conta il numero di valori univoci nella colonna \'qid\'\nunique_values_count = df["qid"].nunique()\n\n# Stampa il numero di valori univoci\nprint(f"Numero di valori univoci nella prima colonna (\'qid\'): {unique_values_count}")'

In [64]:
from pyterrier.measures import *

# Evaluation inputs: TREC DL 2019 test topics and their relevance judgments.
queries = dataset.get_topics("test-2019")
qrels = dataset.get_qrels("test-2019")

# Retrievers over the prebuilt stemmed MS MARCO passage index; the two differ
# only in the weighting model. Alternative that reuses an already-opened index:
#   pt.terrier.Retriever(index, wmodel="TF_IDF")  /  wmodel="BM25"
TF_IDF, BM25 = (
    pt.terrier.Retriever.from_dataset('msmarco_passage', 'terrier_stemmed', wmodel=model_name)
    for model_name in ('TF_IDF', 'BM25')
)

In [65]:
# Mean Average Precision (MAP). For one query, AP is the mean of the precision
# values taken at the rank of each relevant document; relevant documents that
# are never retrieved contribute 0. MAP averages AP over queries.
#   rel=k             -> only labels >= k count as relevant
#   judged_only=True  -> unjudged documents are discarded from the ranking first
map_measures = [
    AP,
    AP(judged_only=True),
    AP(rel=2, judged_only=True),
    AP(rel=3, judged_only=True),
]
pt.Experiment([TF_IDF, BM25], queries, qrels, eval_metrics=map_measures)

Unnamed: 0,name,AP,AP(judged_only=True),"AP(rel=2,judged_only=True)","AP(rel=3,judged_only=True)"
0,TerrierRetr(TF_IDF),0.369486,0.486215,0.345577,0.196327
1,TerrierRetr(BM25),0.370004,0.486302,0.345722,0.19646


In [66]:
# Normalized Discounted Cumulative Gain (nDCG). It uses the graded labels and
# rewards rankings that place the highest-graded documents near the top. The
# score is normalized by the ideal DCG, i.e. documents sorted by descending label.
# "@10" cuts the ranking at depth 10; judged_only=True drops unjudged documents.
ndcg_measures = [nDCG@10, nDCG(judged_only=True)@10, nDCG, nDCG(judged_only=True)]
systems = [TF_IDF, BM25]
pt.Experiment(
    systems,
    queries,
    qrels,
    eval_metrics=ndcg_measures,
)

Unnamed: 0,name,nDCG@10,nDCG(judged_only=True)@10,nDCG,nDCG(judged_only=True)
0,TerrierRetr(TF_IDF),0.47831,0.48248,0.593198,0.635891
1,TerrierRetr(BM25),0.47954,0.48371,0.593433,0.635943


In [67]:
# Binary preference (Bpref). It compares the ranks of judged relevant documents
# against judged non-relevant ones; unjudged documents are ignored entirely.
# rel=k raises the minimum label a document needs to count as relevant.
bpref_measures = [
    Bpref,
    Bpref(rel=2),
    Bpref(rel=3),
]
pt.Experiment([TF_IDF, BM25], queries, qrels, eval_metrics=bpref_measures)

Unnamed: 0,name,Bpref,Bpref(rel=2),Bpref(rel=3)
0,TerrierRetr(TF_IDF),0.494397,0.316205,0.133529
1,TerrierRetr(BM25),0.494534,0.316478,0.133637


In [68]:
# Judged@k: percentage of the top-k (cutoff) retrieved documents that have a
# relevance judgment; plain `Judged` applies no cutoff. Equivalent to P@k with
# a relevance threshold below every label. Values far below 1 mean much of the
# ranking was never assessed.
judged_measures = [Judged, Judged@5, Judged@10, Judged@100]
pt.Experiment(
    [TF_IDF, BM25],
    queries,
    qrels,
    eval_metrics=judged_measures,
)

Unnamed: 0,name,Judged,Judged@5,Judged@10,Judged@100
0,TerrierRetr(TF_IDF),0.142977,0.976744,0.95814,0.551395
1,TerrierRetr(BM25),0.143116,0.976744,0.95814,0.551628


In [69]:
# Disabled (string literal, echoed as the cell output). MAP comparison of
# TF-IDF / BM25 / PL2, rounded to 4 decimals, with significance tests against
# system 0 (TF-IDF) and correction='b', presumably Bonferroni; confirm in the
# PyTerrier docs.
# NOTE(review): `PL2` and `processed_queries` are not defined by the active
# cells visible here. Define them, or switch to `queries`, before re-enabling.
'''pt.Experiment(
    [TF_IDF, BM25, PL2],
    processed_queries,
    qrels,
    eval_metrics=["map"],
    round={"map" : 4},
    names=['TF-IDF', 'BM25', 'PL2'],
    baseline=0,
    correction='b'
)'''

'pt.Experiment(\n    [TF_IDF, BM25, PL2],\n    processed_queries,\n    qrels,\n    eval_metrics=["map"],\n    round={"map" : 4},\n    names=[\'TF-IDF\', \'BM25\', \'PL2\'],\n    baseline=0,\n    correction=\'b\'\n)'

In [70]:
# Disabled (string literal, echoed as the cell output): per-query MAP for
# TF-IDF / BM25 / PL2 (perquery=True).
# NOTE(review): `PL2` and `processed_queries` are not defined by the active
# cells visible here. Define them, or switch to `queries`, before re-enabling.
'''pt.Experiment(
    [TF_IDF, BM25, PL2],
    processed_queries,
    qrels,
    eval_metrics=["map"],
    perquery=True
)'''

'pt.Experiment(\n    [TF_IDF, BM25, PL2],\n    processed_queries,\n    qrels,\n    eval_metrics=["map"],\n    perquery=True\n)'

In [71]:
# Disabled (string literal, echoed as the cell output). Same significance-tested
# MAP comparison as the earlier disabled cell, plus save_dir='./' so PyTerrier
# stores the system runs in the current directory.
# NOTE(review): `PL2` and `processed_queries` are not defined by the active
# cells visible here. Define them, or switch to `queries`, before re-enabling.
'''pt.Experiment(
    [TF_IDF, BM25, PL2],
    processed_queries,
    qrels,
    eval_metrics=["map"],
    round={"map" : 4},
    names=['TF-IDF', 'BM25', 'PL2'],
    baseline=0,
    correction='b',
    save_dir='./'
)'''

'pt.Experiment(\n    [TF_IDF, BM25, PL2],\n    processed_queries,\n    qrels,\n    eval_metrics=["map"],\n    round={"map" : 4},\n    names=[\'TF-IDF\', \'BM25\', \'PL2\'],\n    baseline=0,\n    correction=\'b\',\n    save_dir=\'./\'\n)'

# Trying it out

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import time


# --- UI elements ---
# Free-text query box plus an explicit "Search" button.
search_bar = widgets.Text(
    description='Search:',
    placeholder='Type in a query...',
    layout=widgets.Layout(width='80%'),
)
search_button = widgets.Button(
    description='Search',
    icon='search',
    tooltip='Execute the query',
    button_style='success',
)

# Option groups: scoring function, and query-processing strategy
# (TAAT / DAAT, presumably term-at-a-time vs document-at-a-time).
score_function_rbtn = widgets.RadioButtons(
    options=['TF-IDF', 'BM25'],
    description='Scoring function:',
    disabled=False,
)
algo_rbtn = widgets.RadioButtons(
    options=['TAAT', 'DAAT'],
    description='Algorithm:',
    disabled=False,
)

# Invisible HTML widget whose only job is to inject CSS that lays the radio
# options out horizontally with fixed-width labels.
_radio_css = (
    "<style>.widget-radio-box {flex-direction: row !important;}.widget-radio-box"
    " label{margin:2px !important;width: 100px !important;}</style>"
)
_style = widgets.HTML(_radio_css, layout=widgets.Layout(display="none"))

# Container that receives everything the search handler prints.
output_area = widgets.Output()


def on_search_click(b):
    """Run the query typed in ``search_bar`` and print the ranked results.

    Registered both as the button's ``on_click`` callback and as an ``observe``
    callback on the text box's ``value``. The argument (a Button or a change
    dict) is ignored. All output goes to ``output_area``, which is cleared first.
    An empty or whitespace-only query only prints a hint.
    """
    with output_area:
        clear_output()  # drop the previous query's output
        query = search_bar.value
        if not query.strip():
            print("Please, type in a query.")
            return

        selected_scoring_function = score_function_rbtn.value
        print(f"Selected scoring function: {selected_scoring_function}")
        # Map the UI label to the `method` name passed to taat/daat.
        method = 'tfidf' if selected_scoring_function == 'TF-IDF' else 'bm25'

        selected_algorithm = algo_rbtn.value
        print(f"Selected Algorithm: {selected_algorithm}")

        # perf_counter is monotonic and high-resolution, unlike time.time(), so
        # the reported latency is not distorted by system clock adjustments.
        start_time = time.perf_counter()
        # --- QUERY EXECUTION ---
        processed_query = preprocess(query)
        termids = inv_index.get_termids(processed_query)
        postings = inv_index.get_postings(termids)

        if selected_algorithm == 'TAAT':
            results = taat(postings, method=method)
        else:
            results = daat(postings, method=method)
        # ------------------------
        elapsed_time = (time.perf_counter() - start_time) * 1000  # seconds -> ms

        # Finally show the results.
        # NOTE(review): assumes each result is a (score, doc) pair, as the
        # rounding of res[0] suggests. Confirm against taat/daat.
        print(f"Found: {len(results)} documents\n")
        for res in results:
            res = (round(res[0], 4), res[1])  # TODO: the rounding could be moved directly into the scoring function
            print(f" - {res}")
        print(f"\nExecution time: {elapsed_time:.2f} ms")

# --- Event wiring ---
# Clicking the button runs the search.
search_button.on_click(on_search_click)
# Pressing Enter runs it as well. With continuous_update off, the box's `value`
# is synced only on submission (Enter) or when the box loses focus, and each
# such change calls the same handler.
# NOTE(review): since losing focus also syncs `value`, typing a query and then
# clicking the button may run the search twice. Confirm in the target frontend.
search_bar.continuous_update = False
search_bar.observe(on_search_click, names='value')

# --- Layout ---
# One row per option group (each carrying the hidden CSS widget), then the
# search box with its button; results are rendered below in `output_area`.
top_row, middle_row, bottom_row = (
    widgets.HBox(row_children)
    for row_children in (
        [score_function_rbtn, _style],
        [algo_rbtn, _style],
        [search_bar, search_button],
    )
)

display(widgets.VBox([top_row, middle_row, bottom_row]))
display(output_area)
