In [7]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import linear_kernel

class SimpleTFIDFRetriever:
    def __init__(self, documents):
        self.vectorizer = TfidfVectorizer(stop_words='english')
        self.documents = documents
        self.doc_vectors = self.vectorizer.fit_transform(documents)
    
    def query(self, text, top_n=5):
        # Vectorize the query
        query_vector = self.vectorizer.transform([text])
        
        # Calculate cosine similarity
        cosine_similarities = linear_kernel(query_vector, self.doc_vectors).flatten()
        
        # Get the top_n document indices based on similarity
        related_docs_indices = cosine_similarities.argsort()[:-top_n-1:-1]
        
        # Return the top_n relevant documents
        return [self.documents[i] for i in related_docs_indices]


In [9]:
import os, json
from transformers import BartForConditionalGeneration, BartTokenizer

# Initialize the BART model and tokenizer
model_name = "facebook/bart-base"
model = BartForConditionalGeneration.from_pretrained(model_name)
tokenizer = BartTokenizer.from_pretrained(model_name)

def summarize_text(text):
    # Tokenize the input text
    inputs = tokenizer([text], max_length=1024, return_tensors="pt", truncation=True)
    
    # Generate the summarized version of the input text
    summary_ids = model.generate(inputs["input_ids"], num_beams=4, length_penalty=2.0, max_length=250, min_length=50, no_repeat_ngram_size=2)
    
    # Decode and return the summarized text
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

In [12]:
from tqdm import tqdm

documents = []
folder_path = "/home/etien/Documents/EPFLcourses/MA3/Meditron/Guidelines/split_guidelines/cdc_diseases.jsonl"   

for filename in tqdm(os.listdir(folder_path)):
    if filename.endswith(".json"):
        with open(os.path.join(folder_path, filename), 'r') as f:
            data = json.load(f)
            # summarize the data['text'] before appending it to documents
            summarized_text = summarize_text(data['text'])
            documents.append(summarized_text)

100%|██████████| 46/46 [07:41<00:00, 10.03s/it]


In [13]:
retriever = SimpleTFIDFRetriever(documents)
query_text = "What is Cholera?"
print(retriever.query(query_text))

["Cholera | Disease Directory | Travelers' Health | CDC [CDC] [Health] CDC[CDC], CDC (CDC), CDC, CDC(CDC) CDC-CDC[mother and child washing their hands](/travel/images/handwashing-2022.jpg)CDC/CDC- CDC/CHOLERA CHOLERA CHALLENGEChOLera can be a life-threatening disease caused by bacteria called - Vibrio cholerae-. CholerA person can get choleremia from unsafe food or water. This can happen when cholaera bacteria spread from a person into drinking water or the water used to grow food, prepare food. These bacteria can also occur when stool (poop) in sewage gets into the body and contaminates the food supply.Most people who get C-  Most people will have mild or no symptoms. About 1 in 10 people with choroidal choleum will experience severe symptoms and death. Early symptoms (choleroomptoms](https://www.cdc.gov/c-o-chola.html) include the following:  Irritable diarrhea, sometimes described as “rice-water stools�", "Mpox | Disease Directory | Travelers' Health | CDC's Mpox Information System 