In [6]:
import json
import nltk
from nltk.tokenize import word_tokenize
from collections import Counter
import numpy as np
import math

# Fetch NLTK's Punkt tokenizer models, which word_tokenize (used by preprocess)
# relies on; when they are already installed NLTK just reports that the package
# is up to date.
# NOTE(review): recent NLTK releases load 'punkt_tab' for word_tokenize instead
# of 'punkt' — confirm which resource the installed NLTK version needs.
nltk.download('punkt')

def load_data(filepath):
    """Load and return the parsed contents of a JSON file.

    The file is read as raw bytes and handed to ``json.load``, which detects
    UTF-8/UTF-16/UTF-32 (including a leading UTF-8 BOM) itself. This keeps
    non-ASCII text such as typographic apostrophes intact regardless of the
    platform's default locale encoding, which text mode would otherwise use.

    Args:
        filepath: Path to the JSON file.

    Returns:
        The deserialized JSON value. ``prepare_data`` consumes it as a mapping
        of document id -> {passage id -> passage text}.
    """
    with open(filepath, 'rb') as file:
        return json.load(file)

def preprocess(text):
    """Lowercase *text* and split it into NLTK word tokens.

    Passages and queries both go through this function, so their tokens are
    directly comparable. Returns a list of token strings.
    """
    lowered = text.lower()
    return word_tokenize(lowered)

def prepare_data(data):
    """Tokenise every passage and build per-document term statistics.

    Args:
        data: Mapping of document id -> {passage id -> passage text}.

    Returns:
        A 4-tuple ``(doc_passages, term_frequencies, doc_lengths, total_tokens)``:

        - doc_passages: document id -> list of ``(passage_id, tokens)`` in the
          passages' iteration order.
        - term_frequencies: document id -> ``Counter`` of term counts summed
          over all of that document's passages.
        - doc_lengths: document id -> number of tokens in the document.
        - total_tokens: number of tokens in the whole collection.
    """
    doc_passages = {}
    term_frequencies = {}
    doc_lengths = {}
    total_tokens = 0

    for doc_id, passages in data.items():
        tokenized = []
        doc_tf = Counter()
        for passage_id, text in passages.items():
            tokens = preprocess(text)
            tokenized.append((passage_id, tokens))
            # Counter.update adds counts in place in O(len(tokens)); ``+=``
            # would additionally rescan the entire accumulated counter after
            # each passage to strip non-positive counts.
            doc_tf.update(tokens)
        doc_passages[doc_id] = tokenized
        term_frequencies[doc_id] = doc_tf
        doc_lengths[doc_id] = sum(doc_tf.values())
        total_tokens += doc_lengths[doc_id]

    return doc_passages, term_frequencies, doc_lengths, total_tokens

def query_likelihood(query, term_frequencies, doc_lengths, total_length, mu=2e4):
    """Score every document by Dirichlet-smoothed query log-likelihood.

    For each query token t (repeated tokens count repeatedly), document d adds

        log((tf(t, d) + mu * cf(t) / total_length) / (|d| + mu))

    where cf(t) is t's count across all documents. A token whose smoothed
    probability is 0 (it never occurs in the collection) adds nothing, which
    affects every document equally.

    Args:
        query: Raw query string, tokenised with ``preprocess``.
        term_frequencies: document id -> ``Counter`` of term counts.
        doc_lengths: document id -> document length in tokens.
        total_length: Total number of tokens in the collection.
        mu: Dirichlet smoothing parameter.

    Returns:
        dict mapping document id -> log-likelihood score (higher is better).
    """
    tokenized_query = preprocess(query)

    # The background (collection) probability depends only on the term, so
    # compute it once per distinct query term rather than re-summing over all
    # documents for every (document, term) pair.
    background = {}
    for term in set(tokenized_query):
        collection_freq = sum(tf.get(term, 0) for tf in term_frequencies.values())
        # An empty collection has no background mass (avoids 0 / 0).
        background[term] = collection_freq / total_length if total_length else 0.0

    scores = {}
    for doc_id, tf in term_frequencies.items():
        doc_len = doc_lengths[doc_id]
        score = 0
        for term in tokenized_query:
            smoothed_prob = (tf.get(term, 0) + mu * background[term]) / (doc_len + mu)
            score += math.log(smoothed_prob) if smoothed_prob > 0 else 0
        scores[doc_id] = score
    return scores

def get_top_document_and_passages(doc_passages, term_frequencies, doc_lengths, total_length, query, mu=2e4, top_n=2, passage_mu=None):
    """Pick the best document for *query*, then rank that document's passages.

    The document is chosen with ``query_likelihood``. Each of its passages is
    then scored with the same Dirichlet-smoothed query likelihood, using the
    passage's own term counts and length as the foreground model and the
    whole collection as the background model.

    Args:
        doc_passages: document id -> list of ``(passage_id, tokens)``.
        term_frequencies: document id -> ``Counter`` of term counts.
        doc_lengths: document id -> document length in tokens.
        total_length: Total number of tokens in the collection.
        query: Raw query string.
        mu: Dirichlet smoothing parameter for document scoring (and for
            passage scoring unless *passage_mu* is given).
        top_n: Maximum number of passages to return.
        passage_mu: Optional smoothing parameter for passage scoring;
            ``None`` means "use *mu*". A large mu relative to passage length
            makes passage scores dominated by the background model.

    Returns:
        ``(top_doc, top_passages)``: ``top_passages`` holds up to *top_n*
        ``(score, passage_id, text)`` tuples in descending score order (ties
        keep passage order); ``text`` is the passage's tokens joined by spaces.

    Raises:
        ValueError: if there are no documents to rank.
    """
    if passage_mu is None:
        passage_mu = mu

    scores = query_likelihood(query, term_frequencies, doc_lengths, total_length, mu)
    if not scores:
        raise ValueError("cannot rank passages: the collection contains no documents")
    top_doc = max(scores, key=scores.get)

    query_terms = preprocess(query)
    # The background model is identical for every passage: compute it once
    # per distinct query term instead of once per passage per term.
    background = {}
    for term in set(query_terms):
        collection_freq = sum(tf.get(term, 0) for tf in term_frequencies.values())
        background[term] = collection_freq / total_length if total_length else 0.0

    passage_scores = []
    for passage_id, tokens in doc_passages[top_doc]:
        # One pass to count terms instead of tokens.count(term) per query term.
        passage_tf = Counter(tokens)
        passage_len = len(tokens)
        score = 0
        for term in query_terms:
            smoothed_prob = (passage_tf[term] + passage_mu * background[term]) / (passage_len + passage_mu)
            score += math.log(smoothed_prob) if smoothed_prob > 0 else 0
        passage_scores.append((score, passage_id, ' '.join(tokens)))

    top_passages = sorted(passage_scores, key=lambda entry: entry[0], reverse=True)[:top_n]
    return top_doc, top_passages

def print_results(doc_id, top_passages):
    """Print the winning document id, then each ranked passage.

    Every passage is shown as a "Passage ID: ..., Score: ..." header (score to
    two decimal places), followed by its text and a blank separator line.
    """
    print(f"Most Relevant Document ID: {doc_id}")
    for entry in top_passages:
        score, passage_id, text = entry
        header = f"Passage ID: {passage_id}, Score: {score:.2f}"
        # header, text and an empty string on separate lines == three prints.
        print(header, text, "", sep="\n")

def main(filepath='WikiPassageQA/document_passages.json',
         query="What is the structure of Australia’s members of parliament?",
         top_n=4):
    """Run the retrieval demo end to end and print the results.

    Loads the passage collection from *filepath*, selects the document most
    likely to generate *query*, and prints its *top_n* best passages. The
    defaults reproduce the original hard-coded demo, so ``main()`` with no
    arguments behaves as before.
    """
    data = load_data(filepath)
    doc_passages, term_frequencies, doc_lengths, total_length = prepare_data(data)

    top_doc, top_passages = get_top_document_and_passages(
        doc_passages, term_frequencies, doc_lengths, total_length, query, top_n=top_n
    )

    print_results(top_doc, top_passages)

# Run the demo when this code is executed directly. Inside a notebook cell
# __name__ is also "__main__", so the demo runs when the cell is executed.
if __name__ == "__main__":
    main()


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\titouan\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Most Relevant Document ID: 400
Passage ID: 11, Score: -73.51
the united kingdom elects members of three parliaments : the parliament of the united kingdom , with 650 members elected by the first-past-the-post system to the house of commons , referred to as members of parliament , abbreviated to mp the european parliament , with a maximum of 73 members out of a total of 751 members elected for a five-year term , called members of the european parliament the scottish parliament , with 129 members elected under the additional member system every four years , and called members of the scottish parliament the northern ireland assembly , with 108 members known as members of the legislative assembly . the national assembly for wales , with 60 elected members called assembly member in english , aelod y cynulliad in welsh mps are elected in general elections and by-elections to represent constituencies , and may remain mps until parliament is dissolved , which occurs around five years after the