In [1]:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import numpy as np
import pandas as pd
import re
import time
import evaluate
from evaluate import load
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from textstat import flesch_kincaid_grade




In [3]:
def preprocess_text(text, max_length=512):
    # Remove punctuation and replace them with spaces
    text = re.sub(r'[^\w\s]', ' ', text)
    # Remove redundant spaces and convert the text to lowercase
    text = re.sub(r'\s+', ' ', text).lower().strip()
    # Limit text length to prevent long processing time or memory issues
    if len(text) > max_length:
        text = text[:max_length]
    return text

# Load dataset
df = pd.read_csv("filtered_questions_utf8.csv")

# Extract questions and answers
questions_long = [preprocess_text(q) for q in df['Patient']]
answers_long = [preprocess_text(a) for a in df['Doctor']]

questions_short = [preprocess_text(q) for q in df['Description']]
answers_short = [preprocess_text(a) for a in df['short answer']]

In [4]:
# Load BioBERT
biobert_tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-large-cased-v1.1")
biobert_model = AutoModelForQuestionAnswering.from_pretrained("dmis-lab/biobert-large-cased-v1.1")

config.json:   0%|          | 0.00/289 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


vocab.txt:   0%|          | 0.00/467k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at dmis-lab/biobert-large-cased-v1.1 and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# Load BioGPT-Large
biogpt_tokenizer = AutoTokenizer.from_pretrained("microsoft/BioGPT-Large")
biogpt_model = AutoModelForCausalLM.from_pretrained("microsoft/BioGPT-Large")

In [5]:
# Load Medical-Llama3-8B
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.float16,)
llama_model = AutoModelForCausalLM.from_pretrained("ruslanmv/Medical-Llama3-8B",quantization_config=bnb_config, trust_remote_code=True,use_cache=False,device_map='auto')
llama_tokenizer = AutoTokenizer.from_pretrained("ruslanmv/Medical-Llama3-8B", trust_remote_code=True)
llama_tokenizer.pad_token = llama_tokenizer.eos_token

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [5]:
# Load the text encoders for semantic similarity
sbert_model = SentenceTransformer('all-MiniLM-L6-v2')
scincl_model = SentenceTransformer('malteos/scincl')

# Load perplexity metric
perplexity_metric = load("perplexity", module_type="metric")

In [6]:
def evaluate_model(predictions, references):

    # Exact match
    exact_match = np.mean([1 if p == r else 0 for p, r in zip(predictions, references)])

    # Helper function for extracting character-level trigrams
    def extract_trigrams(text):
        return set([text[i:i+3] for i in range(len(text) - 2)])

    # Define Jaccard similarity calculation (including α and β and harmonic mean)
    def calculate_jaccard_score(pred, ref):
        # Extract trigrams from prediction and reference
        pred_trigrams = extract_trigrams(pred)
        ref_trigrams = extract_trigrams(ref)

        # Calculate alpha (Jaccard index)
        alpha = len(pred_trigrams.intersection(ref_trigrams)) / len(pred_trigrams.union(ref_trigrams)) if pred_trigrams.union(ref_trigrams) else 0

        # Calculate beta (containment measure)
        beta = len(pred_trigrams.intersection(ref_trigrams)) / min(len(pred_trigrams), len(ref_trigrams)) if min(len(pred_trigrams), len(ref_trigrams)) > 0 else 0

        # Calculate harmonic mean Jaccard similarity
        jaccard_score = (2 * alpha * beta) / (alpha + beta) if (alpha + beta) > 0 else 0
        return jaccard_score

    # Calculate Jaccard similarity scores
    jaccard_scores = [calculate_jaccard_score(pred, ref) for pred, ref in zip(predictions, references)]
    jaccard_similarity_score = sum(jaccard_scores) / len(jaccard_scores)

    # Helper function for calculating semantic similarity
    def compute_cosine_similarity(embeddings1, embeddings2):
        similarities = [cosine_similarity([e1], [e2])[0][0] for e1, e2 in zip(embeddings1, embeddings2)]
        return np.mean(similarities)

    # Calculate semantic similarity (using preloaded SBERT and SciNCL embedding models)
    sbert_embeddings_pred = sbert_model.encode(predictions)
    sbert_embeddings_ref = sbert_model.encode(references)
    scincl_embeddings_pred = scincl_model.encode(predictions)
    scincl_embeddings_ref = scincl_model.encode(references)

    sbert_similarity = compute_cosine_similarity(sbert_embeddings_pred, sbert_embeddings_ref)
    scincl_similarity = compute_cosine_similarity(scincl_embeddings_pred, scincl_embeddings_ref)

    # Calculate perplexity
    perplexity_results = perplexity_metric.compute(predictions=[p for p in predictions if p.strip() != ''], model_id='gpt2')
    average_perplexity_score = perplexity_results['mean_perplexity']

    # Calculate readability score
    readability_scores = [flesch_kincaid_grade(p) for p in predictions]
    average_readability_score = sum(readability_scores) / len(readability_scores)

    return {
        "Exact Match": exact_match,
        "Character-level Trigram Jaccard Similarity": jaccard_similarity_score,
        "SBERT Similarity": sbert_similarity,
        "SciNCL Similarity": scincl_similarity,
        "Average Perplexity": average_perplexity_score,
        "Flesch-Kincaid Readability Score": average_readability_score
    }


In [7]:
def get_biobert_answer(question, max_length=512):
    start_time = time.time()
    inputs = biobert_tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = biobert_model(**inputs)
        start_index = outputs.start_logits.argmax()
        end_index = outputs.end_logits.argmax()
        answer = biobert_tokenizer.convert_tokens_to_string(
            biobert_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][start_index:end_index + 1])
        )
    response_time = time.time() - start_time
    return answer, response_time

# Generate answers using BioBERT and calculate average response time
biobert_answers_long, biobert_time_long = zip(*[get_biobert_answer(q) for q in questions_long])
biobert_avg_time_long = np.mean(biobert_time_long)

biobert_answers_short, biobert_time_short = zip(*[get_biobert_answer(q) for q in questions_short])
biobert_avg_time_short = np.mean(biobert_time_short)

In [8]:
biobert_evaluation_long = evaluate_model(biobert_answers_long, answers_long)
biobert_evaluation_short = evaluate_model(biobert_answers_short, answers_short)

print("For long question: ")
print("BioBERT Evaluation:", biobert_evaluation_long)
print("BioBERT Average Response Time:", biobert_avg_time_long)


print("For short question: ")
print("BioBERT Evaluation:", biobert_evaluation_short)
print("BioBERT Average Response Time:", biobert_avg_time_short)

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

For long question: 
BioBERT Evaluation: {'Exact Match': 0.0, 'Character-level Trigram Jaccard Similarity': 0.07365148415518515, 'SBERT Similarity': 0.21539243, 'SciNCL Similarity': 0.8190832, 'Average Perplexity': 7759.960881013137, 'Flesch-Kincaid Readability Score': -0.8948717948717948}
BioBERT Average Response Time: 0.15589275726905236
For short question: 
BioBERT Evaluation: {'Exact Match': 0.0, 'Character-level Trigram Jaccard Similarity': 0.06692054045980415, 'SBERT Similarity': 0.19957225, 'SciNCL Similarity': 0.80776066, 'Average Perplexity': 760.3388841417101, 'Flesch-Kincaid Readability Score': -5.56153846153846}
BioBERT Average Response Time: 0.11660006107428135


In [11]:
from tqdm import tqdm

In [12]:
def get_biogpt_answer(question, max_length=512, max_new_tokens=150):
    start_time = time.time()
    inputs = biogpt_tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
    outputs = biogpt_model.generate(inputs["input_ids"], max_new_tokens=max_new_tokens, num_return_sequences=1)
    answer = biogpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
    response_time = time.time() - start_time
    return answer, response_time

# Generate answers using BioGPT and calculate average response time with progress bars
biogpt_answers_long, biogpt_time_long = zip(*[get_biogpt_answer(q) for q in tqdm(questions_long, desc="Processing long questions")])
biogpt_avg_time_long = np.mean(biogpt_time_long)

biogpt_answers_short, biogpt_time_short = zip(*[get_biogpt_answer(q) for q in tqdm(questions_short, desc="Processing short questions")])
biogpt_avg_time_short = np.mean(biogpt_time_short)

Processing long questions: 100%|██████████| 39/39 [22:28<00:00, 34.58s/it]
Processing short questions: 100%|██████████| 39/39 [06:24<00:00,  9.85s/it]


In [13]:
biogpt_evaluation_long = evaluate_model(biogpt_answers_long, answers_long)
biogpt_evaluation_short = evaluate_model(biogpt_answers_short, answers_short)

print("For long question: ")
print("BioGPT Evaluation:", biogpt_evaluation_long)
print("BioGPT Average Response Time:", biogpt_avg_time_long)

print("For short question: ")
print("BioGPT Evaluation:", biogpt_evaluation_short)
print("BioGPT Average Response Time:", biogpt_avg_time_short)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

For long question: 
BioGPT Evaluation: {'Exact Match': 0.0, 'Character-level Trigram Jaccard Similarity': 0.22446526552432317, 'SBERT Similarity': 0.49925134, 'SciNCL Similarity': 0.8723744, 'Average Perplexity': 51.598384417020355, 'Flesch-Kincaid Readability Score': 15.317948717948717}
BioGPT Average Response Time: 34.58056183350392
For short question: 
BioGPT Evaluation: {'Exact Match': 0.0, 'Character-level Trigram Jaccard Similarity': 0.19079365163406453, 'SBERT Similarity': 0.6421805, 'SciNCL Similarity': 0.8912131, 'Average Perplexity': 156.16183055975497, 'Flesch-Kincaid Readability Score': 10.125641025641023}
BioGPT Average Response Time: 9.852194792185074


In [14]:
def get_llama_answer(question):
    start_time = time.time()
    sys_message = '''
    You are an AI Medical Assistant trained on a vast dataset of health information. Please be thorough and
    provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.
    '''
    # Create messages structured for the chat template
    messages = [{"role": "system", "content": sys_message}, {"role": "user", "content": question}]

    chat_template = "{% for message in messages %}{{message.role}}: {{message.content}}{% endfor %}<|im_start|>assistant"

    # Applying chat template
    prompt = llama_tokenizer.apply_chat_template(messages, chat_template=chat_template, tokenize=False, add_generation_prompt=True)
    inputs = llama_tokenizer(prompt, return_tensors="pt").to("cuda")
    outputs = llama_model.generate(**inputs, max_new_tokens=100, use_cache=True)

    # Extract and return the generated text, removing the prompt
    response_text = llama_tokenizer.batch_decode(outputs)[0].strip()
    answer = response_text.split('<|im_start|>assistant')[-1].strip()
    response_time = time.time() - start_time
    return answer, response_time

# Generate answers using Llama and calculate average response time with progress bars
llama_answers_long, llama_time_long = zip(*[get_llama_answer(q) for q in tqdm(questions_long, desc="Processing long questions")])
llama_avg_time_long = np.mean(llama_time_long)

llama_answers_short, llama_time_short = zip(*[get_llama_answer(q) for q in tqdm(questions_short, desc="Processing short questions")])
llama_avg_time_short = np.mean(llama_time_short)

Processing long questions:   0%|          | 0/39 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
Processing long questions:   3%|▎         | 1/39 [00:43<27:14, 43.01s/it]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Processing long questions:   5%|▌         | 2/39 [00:52<14:20, 23.25s/it]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Processing long questions:   8%|▊         | 3/39 [00:59<09:27, 15.77s/it]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Processing long questions:  10%|█         | 4/39 [01:05<07:02, 12.08s/it]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Processing long questions:  13%|█▎        | 5/39 [01:12<05:46, 10.20s/it]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Process

In [15]:
llama_evaluation_long = evaluate_model(llama_answers_long, answers_long)
llama_evaluation_short = evaluate_model(llama_answers_short, answers_short)

print("For long question: ")
print("Llama Evaluation:", llama_evaluation_long)
print("Llama Average Response Time:", llama_avg_time_long)

print("For short question: ")
print("Llama Evaluation:", llama_evaluation_short)
print("Llama Average Response Time:", llama_avg_time_short)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

For long question: 
Llama Evaluation: {'Exact Match': 0.0, 'Character-level Trigram Jaccard Similarity': 0.2261874717049809, 'SBERT Similarity': 0.506148, 'SciNCL Similarity': 0.8568915, 'Average Perplexity': 21.43885471881964, 'Flesch-Kincaid Readability Score': 8.61794871794872}
Llama Average Response Time: 8.3434403859652
For short question: 
Llama Evaluation: {'Exact Match': 0.0, 'Character-level Trigram Jaccard Similarity': 0.21522961324648554, 'SBERT Similarity': 0.54820853, 'SciNCL Similarity': 0.8715154, 'Average Perplexity': 28.14072602834457, 'Flesch-Kincaid Readability Score': 9.77948717948718}
Llama Average Response Time: 6.834231718992576
