In [58]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import nltk
from nltk.tokenize import word_tokenize
nltk.download('punkt')

# Text normalisation shared by indexing and querying: lowercasing only.
def preprocess(text):
    """Normalise a piece of text before vectorisation.

    Only lowercasing is applied here; tokenisation and stop-word removal are
    left to the downstream vectorizer.

    Args:
        text: The raw string to normalise.

    Returns:
        ``text`` converted to lowercase.
    """
    lowered = text.lower()
    return lowered

# Load a JSON file from disk (json is imported locally inside the function).
def load_data(filepath):
    """Read and decode a JSON file.

    Args:
        filepath: Path to a JSON file. JSON text is UTF-8 by specification
            (RFC 8259), so it is always decoded as UTF-8.

    Returns:
        The decoded JSON value. ``prepare_vsm`` iterates it as
        ``{doc_id: {passage_id: text}}``, so that is the shape expected for
        the WikiPassageQA passages file.
    """
    import json

    # Pass the encoding explicitly: without it, open() uses the platform's
    # locale encoding (e.g. cp1252 on Windows), which mis-decodes or rejects
    # non-ASCII characters such as '’' in the passages.
    with open(filepath, 'r', encoding='utf-8') as file:
        return json.load(file)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\titouan\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [59]:
def prepare_vsm(data):
    """Build a TF-IDF vector space over every passage and every whole document.

    For each document, one corpus row is emitted per passage, followed by one
    extra row containing all of that document's preprocessed passages joined
    with single spaces. ``doc_names`` is aligned with the matrix rows: a
    ``(doc_id, passage_id)`` label for passage rows and ``(doc_id, 'doc')``
    for the combined-document row.

    Args:
        data: Mapping of document ID -> mapping of passage ID -> raw passage
            text (both levels are iterated with ``.items()``).

    Returns:
        A tuple ``(vectorizer, tfidf_matrix, doc_names, doc_passages)``:
        the fitted ``TfidfVectorizer`` (English stop words removed), its
        TF-IDF matrix over the corpus, the per-row labels, and a dict mapping
        each document ID to its list of ``(passage_id, preprocessed_text)``.
    """
    row_labels = []
    texts = []
    passages_by_doc = {}

    for doc_id, passages in data.items():
        cleaned = [(pid, preprocess(raw)) for pid, raw in passages.items()]
        passages_by_doc[doc_id] = cleaned

        # One row per individual passage...
        row_labels.extend((doc_id, pid) for pid, _ in cleaned)
        texts.extend(body for _, body in cleaned)

        # ...plus one row for the whole document (all passages concatenated).
        row_labels.append((doc_id, 'doc'))
        texts.append(" ".join(body for _, body in cleaned))

    vectorizer = TfidfVectorizer(stop_words='english')
    return vectorizer, vectorizer.fit_transform(texts), row_labels, passages_by_doc

def get_most_relevant_document(vectorizer, tfidf_matrix, doc_names, query):
    """Return the ID of the document whose combined-text row best matches ``query``.

    Only rows labelled ``(doc_id, 'doc')`` in ``doc_names`` (the whole-document
    rows) are scored; individual passage rows are ignored.

    Args:
        vectorizer: Fitted vectorizer used to embed the query.
        tfidf_matrix: Matrix whose rows are aligned with ``doc_names``.
        doc_names: List of ``(doc_id, type_id)`` labels, one per matrix row.
        query: Raw query string; it is passed through ``preprocess`` first.

    Returns:
        The ``doc_id`` with the highest cosine similarity. On ties, the
        document encountered first in ``doc_names`` wins.

    Raises:
        ValueError: if ``doc_names`` has no ``'doc'`` rows (``max`` of an empty dict).
    """
    query_vec = vectorizer.transform([preprocess(query)])
    similarities = cosine_similarity(query_vec, tfidf_matrix).flatten()

    doc_score_by_id = {
        doc_id: similarities[row]
        for row, (doc_id, row_kind) in enumerate(doc_names)
        if row_kind == 'doc'
    }
    return max(doc_score_by_id, key=doc_score_by_id.get)

def get_top_passages_from_doc(doc_passages, doc_id, vectorizer, query, top_n=2):
    """Rank the passages of one document by cosine similarity to ``query``.

    Args:
        doc_passages: Mapping of document ID -> list of ``(passage_id, text)``
            pairs. The texts are vectorized as-is, so they are expected to be
            preprocessed already.
        doc_id: Document whose passages are ranked; must be a key of ``doc_passages``.
        vectorizer: Fitted vectorizer used to embed both passages and query.
        query: Raw query string; it is passed through ``preprocess`` first.
        top_n: Maximum number of passages to return.

    Returns:
        Up to ``top_n`` ``(score, passage_id, text)`` tuples, highest score
        first. Passages with equal scores keep their original order.

    Raises:
        KeyError: if ``doc_id`` is not in ``doc_passages``.
    """
    passages = list(doc_passages[doc_id])
    if not passages:
        # Nothing to rank. This also avoids calling cosine_similarity on a 0-row matrix.
        return []

    # The query vector does not depend on the passage, so build it once
    # instead of re-preprocessing and re-transforming it for every passage.
    query_vector = vectorizer.transform([preprocess(query)])

    # Vectorize all passages in a single call. Each row is transformed and
    # normalised independently, so the scores equal the per-passage computation.
    passage_matrix = vectorizer.transform([text for _, text in passages])
    similarities = cosine_similarity(passage_matrix, query_vector).flatten()

    scores = [
        (score, passage_id, text)
        for score, (passage_id, text) in zip(similarities, passages)
    ]
    # list.sort is stable even with reverse=True, so ties keep passage order.
    scores.sort(reverse=True, key=lambda item: item[0])
    return scores[:top_n]

In [60]:
def print_results(doc_id, top_passages):
    """Print the selected document ID followed by each ranked passage.

    Each passage is shown as an ID/score line (score to two decimals), then
    its text, then a blank separator line.

    Args:
        doc_id: Identifier of the most relevant document.
        top_passages: Iterable of ``(score, passage_id, text)`` tuples.
    """
    header = f"Most Relevant Document ID: {doc_id}"
    print(header)
    print("Top Passages:")
    for entry in top_passages:
        score, passage_id, text = entry
        print(f"Passage ID: {passage_id}, Score: {score:.2f}")
        # Passage text followed by an empty separator line.
        print(text, end="\n\n")

In [61]:
def main(filepath='WikiPassageQA/document_passages.json',
         query="What is the structure of Australia’s members of parliament?",
         top_n=2):
    """Run the retrieval pipeline end to end and print the results.

    Loads the passages JSON, builds the TF-IDF space, picks the single most
    relevant document for ``query`` and prints its ``top_n`` best passages.
    The defaults reproduce the original hard-coded run, so ``main()`` behaves
    as before while other questions or files can be passed in.

    Args:
        filepath: Path to the passages JSON file, relative to the notebook's
            working directory.
        query: Natural-language question to retrieve passages for.
        top_n: Number of passages to print from the chosen document.
    """
    data = load_data(filepath)
    vectorizer, tfidf_matrix, doc_names, doc_passages = prepare_vsm(data)
    most_relevant_doc = get_most_relevant_document(vectorizer, tfidf_matrix, doc_names, query)
    top_passages = get_top_passages_from_doc(doc_passages, most_relevant_doc, vectorizer, query, top_n=top_n)
    print_results(most_relevant_doc, top_passages)
main()

Most Relevant Document ID: 400
Top Passages:
Passage ID: 15, Score: 0.41
there are 123 members of parliament in total. they are also alternatively called member of the national assembly. parliamentary elections are traditionally held every five years with no term limits imposed. the 25 provinces of cambodia are represented by the members of parliament in the national assembly. a constituency may have more than one mp, depending on the population. a member of parliament is a member of either of the two chambers of the parliament of the czech republic, although the term members of parliament of the czech republic is commonly referred to deputies of the parliament of the czech republic who are members of the lower house of the parliament, chamber of deputies.

Passage ID: 0, Score: 0.39
a member of parliament is the representative of the voters to a parliament. in many countries with bicameral parliaments, this category includes specifically members of the lower house, as upper houses oft