In [2]:
#Importing all the necessary libraries
import re
from datasets import load_dataset
from neo4j import GraphDatabase
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import math
from sklearn.model_selection import train_test_split

In [None]:
# Load pre-trained model and tokenizer 
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")
model = AutoModel.from_pretrained("dmis-lab/biobert-v1.1")

In [1]:
#Creating the Medical Knowledge Graph
class MedicalKnowledgeGraph:
    def __init__(self, uri, user, password):
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        self.driver.close()

    def create_entity(self, entity_type, name):
        with self.driver.session() as session:
            session.run("MERGE (a:" + entity_type + " {name: $name})", name=name)

    def create_relationship(self, entity1_type, entity1_name, relation, entity2_type, entity2_name):
        with self.driver.session() as session:
            session.run("""
                MATCH (a:""" + entity1_type + """ {name: $entity1_name})
                MATCH (b:""" + entity2_type + """ {name: $entity2_name})
                MERGE (a)-[r:""" + relation + """]->(b)
                """, entity1_name=entity1_name, entity2_name=entity2_name)
#Adding the QA pairs to the graph
    def add_qa_pair(self, question, answer):
        entities = extract_entities(question)
        q_embedding = get_bert_embedding(question).tolist()
        a_embedding = get_bert_embedding(answer).tolist()
        
        with self.driver.session() as session:
            session.run("""
                CREATE (q:Question {text: $question, embedding: $q_embedding})
                CREATE (a:Answer {text: $answer, embedding: $a_embedding})
                CREATE (q)-[:HAS_ANSWER]->(a)
            """, question=question, answer=answer, q_embedding=q_embedding, a_embedding=a_embedding)
            
            for entity, entity_type in entities:
                session.run("""
                    MATCH (q:Question {text: $question})
                    MERGE (e:Entity {name: $entity, type: $entity_type})
                    CREATE (q)-[:CONTAINS]->(e)
                """, question=question, entity=entity, entity_type=entity_type)
#Getting the answer from the graph
    def get_answer(self, question):
        q_embedding = get_bert_embedding(question)
        
        with self.driver.session() as session:
            result = session.run("""
                MATCH (q:Question)-[:HAS_ANSWER]->(a:Answer)
                RETURN q.embedding AS q_embedding, a.text AS answer
            """)
            
            embeddings = []
            answers = []
            for record in result:
                embeddings.append(record["q_embedding"])
                answers.append(record["answer"])
            
            if not embeddings:
                return "No answer found."
            
            similarities = cosine_similarity([q_embedding], embeddings)[0]
            most_similar_index = np.argmax(similarities)
            
            return answers[most_similar_index]
#Getting the BERT embedding of the text
def get_bert_embedding(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
#Simple tokenization of the text
def simple_tokenize(text):
    return re.findall(r'\b\w+\b', text.lower())
#Extracting the entities from the text
def extract_entities(text):
    words = simple_tokenize(text)
    tfidf = TfidfVectorizer().fit_transform([text])
    important_words = [word for word, score in sorted(zip(words, tfidf.toarray()[0]), key=lambda x: x[1], reverse=True)[:10]]
    return [(word, "KEYWORD") for word in important_words]
#Calculating the BLEU score
def calculate_bleu(reference, candidate):
    ref_tokens = simple_tokenize(reference)
    cand_tokens = simple_tokenize(candidate)
    
    # If either reference or candidate is empty, return 0
    if not ref_tokens or not cand_tokens:
        return 0.0
    
    # Calculate n-gram precisions
    max_n = min(4, len(ref_tokens), len(cand_tokens))
    if max_n == 0:  # Additional check
        return 0.0
        
    precisions = []
    for n in range(1, max_n + 1):
        ref_ngrams = set(zip(*[ref_tokens[i:] for i in range(n)]))
        cand_ngrams = list(zip(*[cand_tokens[i:] for i in range(n)]))
        
        # If no n-grams of this length can be formed, skip
        if not cand_ngrams:
            continue
            
        matches = sum(1 for ngram in cand_ngrams if ngram in ref_ngrams)
        precisions.append(matches / len(cand_ngrams) if matches > 0 else 1e-10)  # Small constant instead of 0
    
    # If no precisions were calculated, return 0
    if not precisions:
        return 0.0
    
    # Calculate brevity penalty
    bp = min(1, len(cand_tokens) / len(ref_tokens))
    
    # Calculate final score
    try:
        s = sum(math.log(p) for p in precisions) / len(precisions)
        return bp * math.exp(s)
    except (ValueError, ZeroDivisionError):
        return 0.0
#Processing the batch of questions
def process_batch(graph, batch_df, is_training=True):
    """Process a batch of questions"""
    for index, row in batch_df.iterrows():
        try:
            if is_training:
                graph.add_qa_pair(row['question'], row['answer'])
            else:
                return answer_question(graph, row['question'], row['answer'])
        except Exception as e:
            print(f"Error processing {'training' if is_training else 'test'} question {index}: {str(e)}")
            if not is_training:
                return "Error", 0.0
#Answering the question
def answer_question(graph, question, ground_truth):
    try:
        answer = graph.get_answer(question)
        bleu_score = calculate_bleu(ground_truth, answer)
        return answer, bleu_score
    except Exception as e:
        print(f"Error processing question: {str(e)}")
        return "Error generating answer", 0.0
#Loading and preprocessing the data
def load_and_preprocess_data():
    dataset = load_dataset("GBaker/MedQA-USMLE-4-options")
    def preprocess(data):
        return {
            'question': data['question'],
            'answer': data['answer'],
            'options': data['options'],
            'meta_info': data['meta_info']
        }
    preprocessed_data = dataset.map(preprocess)
    df = pd.DataFrame(preprocessed_data['train'])
    
    # Take first 10000 samples
    SAMPLE_SIZE = 1000
    df = df.head(SAMPLE_SIZE)
    return df
#Main function
def main():
    # Constants
    BATCH_SIZE = 100
    TOTAL_SAMPLES = 1000
    TEST_SIZE = 0.2
    
    print("Loading and preprocessing dataset...")
    df = load_and_preprocess_data()
    print(f"Dataset loaded. Total samples: {len(df)}")

    # Split data
    train_df, test_df = train_test_split(df, test_size=TEST_SIZE, random_state=42)
    print(f"Training set size: {len(train_df)}, Test set size: {len(test_df)}")

    # Initialize Neo4j connection
    graph = MedicalKnowledgeGraph("bolt://localhost:7687", "neo4j", "123456789")
    print("Connected to Neo4j database.")

    # Process training data in batches
    print("\nPopulating knowledge graph with training data...")
    num_batches = len(train_df) // BATCH_SIZE + (1 if len(train_df) % BATCH_SIZE != 0 else 0)
    
    for batch_num in range(num_batches):
        start_idx = batch_num * BATCH_SIZE
        end_idx = min((batch_num + 1) * BATCH_SIZE, len(train_df))
        batch_df = train_df.iloc[start_idx:end_idx]
        
        print(f"Processing training batch {batch_num + 1}/{num_batches}")
        process_batch(graph, batch_df, is_training=True)

    # Evaluate on test data
    print("\nEvaluating system on test data...")
    total_bleu_score = 0
    processed_count = 0
    results = []

    for index, row in test_df.iterrows():
        try:
            question = row['question']
            ground_truth = row['answer']
            generated_answer, bleu_score = answer_question(graph, question, ground_truth)
            
            results.append({
                'index': index,
                'question': question,
                'generated_answer': generated_answer,
                'ground_truth': ground_truth,
                'bleu_score': bleu_score
            })
            
            total_bleu_score += bleu_score
            processed_count += 1

            # Print progress every 100 questions
            if processed_count % 100 == 0:
                print(f"Processed {processed_count}/{len(test_df)} test questions")
                
        except Exception as e:
            print(f"Error processing test question {index}: {str(e)}")
            continue

    # Calculate and display results
    average_bleu_score = total_bleu_score / processed_count if processed_count > 0 else 0
    
    print("\nEvaluation Results:")
    print(f"Total questions processed: {processed_count}")
    print(f"Average BLEU Score: {average_bleu_score:.4f}")

    # Display sample results
    print("\nSample Results (first 10 questions):")
    for result in results[:10]:
        print(f"\nQuestion {result['index']}:")
        print(f"Question: {result['question']}")
        print(f"Generated Answer: {result['generated_answer']}")
        print(f"Ground Truth: {result['ground_truth']}")
        print(f"BLEU Score: {result['bleu_score']:.4f}")

    # Save results to file
    try:
        with open('evaluation_results.txt', 'w') as f:
            f.write(f"Evaluation Results:\n")
            f.write(f"Total questions processed: {processed_count}\n")
            f.write(f"Average BLEU Score: {average_bleu_score:.4f}\n\n")
            
            f.write("Detailed Results:\n")
            for result in results:
                f.write(f"\nQuestion {result['index']}:\n")
                f.write(f"Question: {result['question']}\n")
                f.write(f"Generated Answer: {result['generated_answer']}\n")
                f.write(f"Ground Truth: {result['ground_truth']}\n")
                f.write(f"BLEU Score: {result['bleu_score']:.4f}\n")
        print("\nResults saved to evaluation_results.txt")
    except Exception as e:
        print(f"Error saving results to file: {str(e)}")

    # Cleanup
    graph.close()
    print("\nKnowledge graph construction and evaluation completed.")

if __name__ == "__main__":
    main()

Loading and preprocessing dataset...
Dataset loaded. Total samples: 1000
Training set size: 800, Test set size: 200
Connected to Neo4j database.

Populating knowledge graph with training data...
Processing training batch 1/8
Processing training batch 2/8
Processing training batch 3/8
Processing training batch 4/8
Processing training batch 5/8
Processing training batch 6/8
Processing training batch 7/8
Processing training batch 8/8

Evaluating system on test data...
Processed 100/200 test questions
Processed 200/200 test questions

Evaluation Results:
Total questions processed: 200
Average BLEU Score: 0.8650

Sample Results (first 10 questions):

Question 521:
Question: An investigator is studying the mechanism of HIV infection in cells obtained from a human donor. The effect of a drug that impairs viral fusion and entry is being evaluated. This drug acts on a protein that is cleaved off of a larger glycosylated protein in the endoplasmic reticulum of the host cell. The protein that is 