In [None]:
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer

In [None]:
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.datasets import fetch_20newsgroups

class BM25:
    """Okapi BM25 ranker over an in-memory list of text documents.

    Documents and queries are tokenized by the same ``CountVectorizer``
    analyzer (lowercasing, default ``token_pattern``). Document frequencies,
    term frequencies and document lengths therefore all count the same terms.

    Scoring per query term ``t`` and document ``d``::

        idf(t) * tf * (k1 + 1) / (tf + k1 * (1 - b + b * |d| / avgdl))

    with ``idf(t) = log((N - df + 0.5) / (df + 0.5) + 1)``, which is always
    positive.

    Args:
        documents: Sequence of raw document strings.
        k1: Term-frequency saturation parameter.
        b: Length-normalization strength (0 disables it, 1 is full).

    Raises:
        ValueError: Raised by ``CountVectorizer`` when the documents contain
            no tokens at all (e.g. an empty ``documents`` list).
    """

    def __init__(self, documents, k1=1.5, b=0.75):
        self.k1 = k1
        self.b = b
        self.documents = documents
        self.N = len(documents)
        # Fits the tokenizer and builds the term-count matrix self._tf.
        self.doc_freqs = self._compute_doc_freqs()
        # Lengths are counted in analyzer tokens so they agree with the tf
        # counts. Whitespace splitting would also count punctuation-only and
        # 1-character tokens that never appear in the vocabulary.
        lengths = np.asarray(self._tf.sum(axis=1)).ravel()
        self.doc_lengths = lengths.tolist()
        # Strictly positive: the vectorizer raises on a corpus with no tokens.
        self.avgdl = float(lengths.mean())

    def _compute_doc_freqs(self):
        """Fit the tokenizer on ``self.documents`` and return ``{term: df}``.

        ``df`` is the number of documents containing the term at least once.

        Side effects: sets ``self.vocabulary`` (terms in column order),
        ``self._tf`` (sparse N x V term-count matrix, CSC so a term's postings
        are one contiguous slice), ``self._term_index`` (term -> column) and
        ``self._analyzer`` (callable that tokenizes a query exactly like the
        documents).
        """
        vectorizer = CountVectorizer()
        counts = vectorizer.fit_transform(self.documents).tocsc()
        counts.sum_duplicates()  # canonical form: one entry per (doc, term)
        self._tf = counts
        self._term_index = vectorizer.vocabulary_
        self._analyzer = vectorizer.build_analyzer()
        self.vocabulary = vectorizer.get_feature_names_out()
        doc_freqs = np.array((counts > 0).sum(axis=0)).flatten()
        return dict(zip(self.vocabulary, doc_freqs))

    def _idf(self, term):
        """Return the BM25 inverse document frequency of ``term`` (0 df if unseen)."""
        df = self.doc_freqs.get(term, 0)
        return np.log((self.N - df + 0.5) / (df + 0.5) + 1.0)

    def _length_norm(self, doc_len):
        """Return ``k1 * (1 - b + b * doc_len / avgdl)`` for a scalar or array.

        Reads ``k1``/``b`` at call time, so changing them after construction
        takes effect on the next score.
        """
        return self.k1 * (1 - self.b + self.b * (doc_len / self.avgdl))

    def _score(self, query, doc_index):
        """Return the BM25 score of document ``doc_index`` for ``query``.

        The query is tokenized with the document analyzer. Terms outside the
        corpus vocabulary contribute nothing, and a repeated query term
        contributes once per occurrence.
        """
        score = 0.0
        norm = self._length_norm(self.doc_lengths[doc_index])
        for term in self._analyzer(query):
            col = self._term_index.get(term)
            if col is None:
                continue
            tf = self._tf[doc_index, col]
            if tf == 0:
                # Absent term adds 0; skipping also avoids 0/0 when k1 == 0.
                continue
            score += self._idf(term) * tf * (self.k1 + 1) / (tf + norm)
        return score

    def search(self, query):
        """Rank every document for ``query``.

        Returns:
            Array of document indices ordered by descending BM25 score.
            Documents matching no query term (score 0) come last.
        """
        scores = np.zeros(self.N)
        norms = self._length_norm(np.asarray(self.doc_lengths, dtype=float))
        tf_matrix = self._tf
        for term in self._analyzer(query):
            col = self._term_index.get(term)
            if col is None:
                continue
            # Only documents containing the term: the column's stored entries.
            start, stop = tf_matrix.indptr[col], tf_matrix.indptr[col + 1]
            rows = tf_matrix.indices[start:stop]
            tf = tf_matrix.data[start:stop]
            scores[rows] += (
                self._idf(term) * tf * (self.k1 + 1) / (tf + norms[rows])
            )
        return np.argsort(scores)[::-1]  # Sort by descending scores

# Build a BM25 index over the full 20 Newsgroups corpus. Headers, signature
# footers and quoted replies are stripped so each post is ranked on its own text.
newsgroups = fetch_20newsgroups(
    subset='all',
    remove=('headers', 'footers', 'quotes'),
)
documents = newsgroups.data
bm25 = BM25(documents)

# Rank the whole corpus against a sample query.
query = "space exploration"
results = bm25.search(query)

# Print the five best matches as a 500-character preview plus the score.
TOP_K = 5
PREVIEW_CHARS = 500
for index in results[:TOP_K]:
    preview = documents[index][:PREVIEW_CHARS]
    score = bm25._score(query, index)
    print(f"Document: {preview}...")
    print(f"Score: {score}")
    print()