In [1]:
import os
from functools import lru_cache

import numpy as np
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity


def get_sentences_from_document(filepath, encoding='utf-8'):
    """Read a text file and split its contents into sentences.

    Args:
        filepath: Path of the text file to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8
            so the result does not depend on the platform's locale encoding
            (which on Windows is typically not UTF-8).

    Returns:
        List of sentence strings produced by NLTK's ``sent_tokenize``.
    """
    # An explicit encoding avoids UnicodeDecodeError / mis-decoded text when
    # the platform default differs from the files' actual encoding.
    with open(filepath, 'r', encoding=encoding) as file:
        document = file.read()
    return sent_tokenize(document)


@lru_cache(maxsize=4)
def _load_sentence_model(model_name):
    """Load a SentenceTransformer once per model name and reuse it afterwards."""
    return SentenceTransformer(model_name)


def get_document_embedding(sentences, model_name='bert-base-nli-mean-tokens'):
    """Embed a document as the mean of its sentence embeddings.

    Args:
        sentences: Sequence of sentence strings making up the document.
        model_name: sentence-transformers model identifier; defaults to the
            model this script has always used.

    Returns:
        A numpy vector: the element-wise mean (over sentences, axis 0) of the
        embeddings returned by ``model.encode``. For a document with no
        sentences, a zero vector of the model's embedding size is returned
        instead of the NaN result ``np.mean`` would give for empty input
        (NaN would make the later ``cosine_similarity`` call fail).
    """
    # Loading the model is expensive; the cached loader keeps one instance per
    # model name instead of reloading it for every document.
    model = _load_sentence_model(model_name)
    if len(sentences) == 0:
        return np.zeros(model.get_sentence_embedding_dimension(), dtype=np.float32)
    sentence_embeddings = model.encode(sentences)
    return np.mean(sentence_embeddings, axis=0)


def rank_documents(query_embedding, document_embeddings, documents, k):
    """Return the ``k`` documents most similar to the query, best first.

    Args:
        query_embedding: Query vector; reshaped to a single row before
            comparison.
        document_embeddings: 2-D array with one row per entry of
            ``documents``.
        documents: Labels (here, filenames) aligned with the rows of
            ``document_embeddings``.
        k: Maximum number of documents to return. ``None`` returns all;
            values <= 0 return an empty list.

    Returns:
        List of entries from ``documents`` ordered by descending cosine
        similarity. Equal scores keep their original relative order.
    """
    scores = cosine_similarity(np.asarray(query_embedding).reshape(1, -1), document_embeddings)[0]
    # Stable sort on the negated scores: descending similarity, with ties kept
    # in input order instead of the arbitrary order of the default quicksort.
    order = np.argsort(-scores, kind='stable')
    if k is not None:
        # Clamp so a negative k cannot turn into "drop the last |k| items".
        order = order[:max(k, 0)]
    return [documents[i] for i in order]


def process_query_case(query_case_filepath, corpus_folder, k, batch_size=50):
    """Rank corpus documents against a query document, one batch at a time.

    The query file is embedded once. Corpus files are then embedded one by
    one. Every ``batch_size`` files, the accumulated batch is ranked against
    the query and the top ``k`` filenames *of that batch* are yielded. A final,
    possibly smaller, batch is yielded for any leftover files. Rankings are
    therefore local to each batch, not global across the whole corpus.

    Args:
        query_case_filepath: Path of the query document.
        corpus_folder: Directory whose regular files form the corpus.
        k: Maximum number of filenames yielded per batch.
        batch_size: Number of corpus files per batch.

    Yields:
        Lists of corpus filenames (not full paths), most similar first.
    """
    query_embedding = get_document_embedding(get_sentences_from_document(query_case_filepath))

    def _rank_batch(embeddings, names):
        # Stack the per-document embedding vectors into one matrix (one row
        # per document) and rank that batch against the query.
        return rank_documents(query_embedding, np.vstack(embeddings), names, k)

    batch_embeddings = []
    batch_names = []

    # sorted() makes batch membership deterministic; os.listdir order is
    # arbitrary and platform-dependent.
    for document_file in sorted(os.listdir(corpus_folder)):
        document_filepath = os.path.join(corpus_folder, document_file)
        # Subdirectories and other non-regular entries cannot be read with open().
        if not os.path.isfile(document_filepath):
            continue
        document_sentences = get_sentences_from_document(document_filepath)
        batch_embeddings.append(get_document_embedding(document_sentences))
        batch_names.append(document_file)

        if len(batch_embeddings) == batch_size:
            ranked_documents = _rank_batch(batch_embeddings, batch_names)
            batch_embeddings = []
            batch_names = []
            yield ranked_documents

    if batch_embeddings:
        yield _rank_batch(batch_embeddings, batch_names)


DEFAULT_QUERY_CASE_FILEPATH = r"C:\Users\This PC\Desktop\Task_1\Test_catches\case_100_catchwords.txt"
DEFAULT_CORPUS_FOLDER = r"C:\Users\This PC\Desktop\Task_1\Test_catches"


def main(query_case_filepath=DEFAULT_QUERY_CASE_FILEPATH,
         corpus_folder=DEFAULT_CORPUS_FOLDER,
         k=100,
         batch_size=50):
    """Print the per-batch similarity ranking of a corpus against a query file.

    Args:
        query_case_filepath: Path of the query document. With the defaults it
            lives inside ``corpus_folder``, so it is also ranked against itself.
        corpus_folder: Directory of documents to rank.
        k: Maximum number of documents printed *per batch*. Rankings are
            computed within each batch, so a ``k`` >= ``batch_size`` prints
            every document of every batch.
        batch_size: Number of corpus documents ranked together.
    """
    batch_generator = process_query_case(query_case_filepath, corpus_folder, k, batch_size)

    for i, ranked_documents in enumerate(batch_generator):
        print(f"Batch {i+1} - Ranked Documents:")
        for j, document in enumerate(ranked_documents):
            print(f"Rank {j+1}: {document}")
        print()


if __name__ == '__main__':
    main()


Batch 1 - Ranked Documents:
Rank 1: case_100_catchwords.txt
Rank 2: case_107_catchwords.txt
Rank 3: case_141_catchwords.txt
Rank 4: case_102_catchwords.txt
Rank 5: case_131_catchwords.txt
Rank 6: case_114_catchwords.txt
Rank 7: case_143_catchwords.txt
Rank 8: case_106_catchwords.txt
Rank 9: case_145_catchwords.txt
Rank 10: case_129_catchwords.txt
Rank 11: case_127_catchwords.txt
Rank 12: case_115_catchwords.txt
Rank 13: case_110_catchwords.txt
Rank 14: case_137_catchwords.txt
Rank 15: case_146_catchwords.txt
Rank 16: case_132_catchwords.txt
Rank 17: case_126_catchwords.txt
Rank 18: case_147_catchwords.txt
Rank 19: case_101_catchwords.txt
Rank 20: case_120_catchwords.txt
Rank 21: case_135_catchwords.txt
Rank 22: case_118_catchwords.txt
Rank 23: case_134_catchwords.txt
Rank 24: case_148_catchwords.txt
Rank 25: case_130_catchwords.txt
Rank 26: case_133_catchwords.txt
Rank 27: case_136_catchwords.txt
Rank 28: case_109_catchwords.txt
Rank 29: case_105_catchwords.txt
Rank 30: case_121_catchw

Batch 6 - Ranked Documents:
Rank 1: case_364_catchwords.txt
Rank 2: case_382_catchwords.txt
Rank 3: case_359_catchwords.txt
Rank 4: case_369_catchwords.txt
Rank 5: case_379_catchwords.txt
Rank 6: case_355_catchwords.txt
Rank 7: case_358_catchwords.txt
Rank 8: case_380_catchwords.txt
Rank 9: case_350_catchwords.txt
Rank 10: case_376_catchwords.txt
Rank 11: case_365_catchwords.txt
Rank 12: case_373_catchwords.txt
Rank 13: case_367_catchwords.txt
Rank 14: case_353_catchwords.txt
Rank 15: case_384_catchwords.txt
Rank 16: case_370_catchwords.txt
Rank 17: case_351_catchwords.txt
Rank 18: case_385_catchwords.txt
Rank 19: case_360_catchwords.txt
Rank 20: case_362_catchwords.txt
Rank 21: case_388_catchwords.txt
Rank 22: case_377_catchwords.txt
Rank 23: case_366_catchwords.txt
Rank 24: case_391_catchwords.txt
Rank 25: case_375_catchwords.txt
Rank 26: case_371_catchwords.txt
Rank 27: case_386_catchwords.txt
Rank 28: case_361_catchwords.txt
Rank 29: case_378_catchwords.txt
Rank 30: case_383_catchw