<a href="https://colab.research.google.com/github/chapSKor/basicRAGs/blob/main/RAG_medical_multiple_t5_large.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This code will hypertune Googles t5-large model to multiple medical libraries to create a more advanced chatbot.

In [1]:
pip install datasets transformers sentence-transformers torch

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (1

In [4]:
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from transformers import T5ForConditionalGeneration, T5Tokenizer
import pandas as pd


In [18]:
#medmcqa - Medical Multiple Choice Question Answering dataset.
#Great for general medical knowledge.
#n2c2_2006_deid - A dataset containing de-identified clinical notes.
#Useful for extracting clinical language patterns.
#pubmed_abstracts - A dataset with abstracts from PubMed.
#Contains biomedical research articles.
#medical_dialogue - Dataset with doctor-patient dialogues.
#Useful for medical conversations and context.
#emrqa - Question answering dataset based on electronic medical records.
#Focuses on extracting information from EMRs.
#covid_qa_deepset - COVID-19 related question-answer dataset.
#Specific to recent COVID-19 developments.

# Step 1: Define datasets to load
datasets_to_load = [
    ('medmcqa', 'default'),
    ('bioasq', 'train'),
    ('pubmed_qa', 'pqa_labeled'),
    ('n2c2_2006_deid', None),
    ('pubmed_abstracts', None),
    ('medical_dialogue', None),
    ('emrqa', None),
    ('covid_qa_deepset', None),
    ('clinical_trials', None)
]

corpus = []

# Load multiple datasets and handle any potential errors
for dataset_name, config in datasets_to_load:
    try:
        if config:
            dataset = load_dataset(dataset_name, config, split='train')
        else:
            dataset = load_dataset(dataset_name, split='train')

        # Extract text fields and prioritize non-question fields
        if 'context' in dataset.column_names:
            corpus.extend(dataset['context'])
        elif 'text' in dataset.column_names:
            corpus.extend(dataset['text'])
        elif 'abstract' in dataset.column_names:
            corpus.extend(dataset['abstract'])
        elif 'description' in dataset.column_names:
            corpus.extend(dataset['description'])
        elif 'dialogue' in dataset.column_names:
            corpus.extend(dataset['dialogue'])
        elif 'question' in dataset.column_names:
            corpus.extend(dataset['question'])
        print(f"Loaded dataset: {dataset_name}")

    except Exception as e:
        print(f"Failed to load dataset {dataset_name}: {e}")

# Check if the corpus is empty
if not corpus:
    raise ValueError("No data loaded. Please check the dataset names and configurations.")


Loaded dataset: medmcqa
Failed to load dataset bioasq: Dataset 'bioasq' doesn't exist on the Hub or cannot be accessed.
Loaded dataset: pubmed_qa
Failed to load dataset n2c2_2006_deid: Dataset 'n2c2_2006_deid' doesn't exist on the Hub or cannot be accessed.
Failed to load dataset pubmed_abstracts: Dataset 'pubmed_abstracts' doesn't exist on the Hub or cannot be accessed.
Failed to load dataset medical_dialogue: Dataset 'medical_dialogue' doesn't exist on the Hub or cannot be accessed.
Failed to load dataset emrqa: Dataset 'emrqa' doesn't exist on the Hub or cannot be accessed.
Loaded dataset: covid_qa_deepset
Failed to load dataset clinical_trials: Dataset 'clinical_trials' doesn't exist on the Hub or cannot be accessed.


In [21]:
# Step 2: Initialize the Sentence Transformer for text retrieval
retriever_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
corpus_embeddings = retriever_model.encode(corpus, convert_to_tensor=True)

# Step 3: Load the T5-Large model for text generation
t5_model = T5ForConditionalGeneration.from_pretrained('t5-large')
t5_tokenizer = T5Tokenizer.from_pretrained('t5-large')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
t5_model.to(device)

KeyboardInterrupt: 

In [19]:
# Function to retrieve relevant passages from the combined corpus
def retrieve_passages(query, top_k):
    query_embedding = retriever_model.encode(query, convert_to_tensor=True)
    similarities = util.cos_sim(query_embedding, corpus_embeddings)
    top_k_indices = torch.topk(similarities, k=top_k).indices.flatten()
    retrieved_passages = [corpus[idx.item()] for idx in top_k_indices]
    return retrieved_passages

# Function to generate a response using T5-Large
def generate_response(query, retrieved_passages):
    context = " ".join(retrieved_passages)
    input_text = f"question: {query} context: {context}"

    input_ids = t5_tokenizer.encode(input_text, return_tensors='pt', max_length=512, truncation=True).to(device)
    output_ids = t5_model.generate(
        input_ids,
        max_length=500,
        min_length=60,
        num_beams=5,
        early_stopping=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1
    )
    response = t5_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return response



In [17]:
# Example Usage
query = "What are the current treatments for diabetes?"
retrieved_passages = retrieve_passages(query, top_k=6)
response = generate_response(query, retrieved_passages)

print("\nQuery:", query)
print("\nRetrieved Passages:")
for passage in retrieved_passages:
    print("-", passage)
print("\nGenerated Response:", response)


Query: What are the latest advancements in cancer treatment?

Retrieved Passages:
- Is breast cancer survival improving?
- What killed prostate cancer cells in vitro?
- What are some non-pharmaceutical interventions?
- What is the third most prevalent cancer in females in the United States?
- Does high-dose radiotherapy benefit palliative lung cancer patients?

Generated Response: the third most prevalent cancer in females in the United States? Does high-dose radiotherapy benefit palliative lung cancer patients? Does high-dose radiotherapy benefit palliative lung cancer patients? Does high-dose radiotherapy benefit palliative lung cancer patients? Does high-dose radiotherapy
