In [10]:
!pip install nltk

Collecting nltk
  Downloading nltk-3.9.1-py3-none-any.whl.metadata (2.9 kB)
Collecting regex>=2021.8.3 (from nltk)
  Downloading regex-2024.9.11-cp38-cp38-win_amd64.whl.metadata (41 kB)
Downloading nltk-3.9.1-py3-none-any.whl (1.5 MB)
   ---------------------------------------- 0.0/1.5 MB ? eta -:--:--
   ---------------------------------------- 1.5/1.5 MB 20.1 MB/s eta 0:00:00
Downloading regex-2024.9.11-cp38-cp38-win_amd64.whl (274 kB)
Installing collected packages: regex, nltk
Successfully installed nltk-3.9.1 regex-2024.9.11


In [None]:
import nltk
# NOTE(review): called with no arguments, nltk.download() appears to start NLTK's
# interactive downloader and wait for user input, so this cell will not finish
# under "Restart & Run All". None of the code visible below uses nltk; if a
# specific resource is needed, pass its identifier to nltk.download(...) instead
# - TODO confirm which resource (if any) is required.
nltk.download()

showing info https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/index.xml


In [4]:
pip install transformers torch scikit-learn nltk

Collecting transformers
  Downloading transformers-4.45.2-py3-none-any.whl.metadata (44 kB)
Collecting safetensors>=0.4.1 (from transformers)
  Downloading safetensors-0.4.5-cp38-none-win_amd64.whl.metadata (3.9 kB)
Collecting tokenizers<0.21,>=0.20 (from transformers)
  Downloading tokenizers-0.20.1-cp38-none-win_amd64.whl.metadata (6.9 kB)
Downloading transformers-4.45.2-py3-none-any.whl (9.9 MB)
   ---------------------------------------- 0.0/9.9 MB ? eta -:--:--
   --------------------------------- ------ 8.4/9.9 MB 47.2 MB/s eta 0:00:01
   ---------------------------------------- 9.9/9.9 MB 43.9 MB/s eta 0:00:00
Downloading safetensors-0.4.5-cp38-none-win_amd64.whl (286 kB)
Downloading tokenizers-0.20.1-cp38-none-win_amd64.whl (2.4 MB)
   ---------------------------------------- 0.0/2.4 MB ? eta -:--:--
   ---------------------------------------- 2.4/2.4 MB 34.2 MB/s eta 0:00:00
Installing collected packages: safetensors, tokenizers, transformers
Successfully installed safetensors

In [20]:
import re
from datasets import load_dataset
from neo4j import GraphDatabase
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import math

# Load the pre-trained tokenizer and encoder. Both are fetched from the same
# checkpoint identifier so that the vocabulary matches the model weights.
BIOBERT_CHECKPOINT = "dmis-lab/biobert-v1.1"
tokenizer = AutoTokenizer.from_pretrained(BIOBERT_CHECKPOINT)
model = AutoModel.from_pretrained(BIOBERT_CHECKPOINT)

In [21]:
def get_bert_embedding(text):
    """Embed ``text`` with the module-level ``tokenizer`` / ``model`` pair.

    The input is tokenized (truncated to at most 512 tokens), run through the
    model without gradient tracking, and the final hidden states are averaged
    over the token axis. Only positions marked as real tokens by the attention
    mask contribute to the average, so padding added when several texts of
    different lengths are passed at once does not dilute the result. For a
    single string the mask is all ones and this equals a plain mean.

    Args:
        text: A string (or a list of strings) accepted by the tokenizer.

    Returns:
        numpy.ndarray: The pooled embedding with size-1 dimensions squeezed
        out, i.e. a 1-D vector for a single input text.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    # Assumes last_hidden_state is laid out as (batch, tokens, hidden), which is
    # what pooling over dim=1 implies.
    hidden = outputs.last_hidden_state
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    # clamp guards against a division by zero should a row have an all-zero mask.
    token_counts = mask.sum(dim=1).clamp(min=1)
    pooled = (hidden * mask).sum(dim=1) / token_counts
    return pooled.squeeze().numpy()

In [22]:
def simple_tokenize(text):
    """Split ``text`` into lower-cased word tokens.

    A token is a maximal run of word characters (letters, digits, underscore);
    punctuation and whitespace act as separators and are discarded.

    Args:
        text: The string to tokenize.

    Returns:
        list[str]: The tokens in order of appearance.
    """
    lowered = text.lower()
    return [match.group() for match in re.finditer(r'\b\w+\b', lowered)]


In [23]:
def extract_entities(text, top_k=10):
    """Return the highest-weighted terms of ``text`` as ``(term, "KEYWORD")`` pairs.

    A ``TfidfVectorizer`` is fitted on ``text`` alone and each vocabulary term is
    ranked by its weight in the resulting row. Because only one document is
    fitted, the ranking reflects how often a term occurs in ``text``.

    Each weight is looked up through ``vectorizer.vocabulary_`` (term -> column
    index). Pairing the weights with the words in their order of appearance
    would be wrong: the vectorizer's columns follow its own vocabulary, not the
    input order, and its tokenization need not match ``simple_tokenize``.

    Args:
        text: The text to extract keywords from.
        top_k: Maximum number of keywords to return (default 10).

    Returns:
        list[tuple[str, str]]: Up to ``top_k`` distinct ``(term, "KEYWORD")``
        pairs, highest weight first; ties are broken alphabetically. An empty
        list is returned when the vectorizer finds no usable term in ``text``.
    """
    vectorizer = TfidfVectorizer()
    try:
        weights = vectorizer.fit_transform([text]).toarray()[0]
    except ValueError:
        # Raised by scikit-learn when the vocabulary is empty (for example an
        # empty string); there is nothing to extract in that case.
        return []

    ranked = sorted(
        vectorizer.vocabulary_.items(),
        key=lambda item: (-weights[item[1]], item[0]),
    )
    return [(term, "KEYWORD") for term, _ in ranked[:top_k]]

In [24]:
def load_and_preprocess_data():
    """Load the ``GBaker/MedQA-USMLE-4-options`` dataset and return its train split.

    Every example is passed through a mapper that reads the ``question``,
    ``answer``, ``options`` and ``meta_info`` fields and returns them unchanged;
    the mapped ``train`` split is then converted to a pandas DataFrame.

    NOTE(review): whether columns other than these four remain in the result
    depends on how ``map`` merges the returned dict - confirm if it matters.

    Returns:
        pandas.DataFrame: One row per training example.
    """
    kept_fields = ('question', 'answer', 'options', 'meta_info')

    def preprocess(data):
        # Accessing each field also fails fast if the dataset lacks a column.
        return {field: data[field] for field in kept_fields}

    raw_splits = load_dataset("GBaker/MedQA-USMLE-4-options")
    mapped_splits = raw_splits.map(preprocess)
    return pd.DataFrame(mapped_splits['train'])


In [25]:
class MedicalKnowledgeGraph:
    """Small wrapper around a Neo4j driver for storing and querying QA pairs.

    Questions and answers are stored as ``Question`` / ``Answer`` nodes that
    carry their text and embedding; keywords extracted from a question are
    stored as ``Entity`` nodes linked from the question with ``CONTAINS``.

    The object can be used as a context manager (``with MedicalKnowledgeGraph(...)
    as graph:``) so the driver is closed even if an error occurs.
    """

    # Labels and relationship types are spliced into the query text (they are not
    # sent as query parameters), so they are restricted to plain identifiers to
    # keep a crafted value from changing the structure of the query.
    _NAME = r"[^\W\d]\w*"
    _LABEL_RE = re.compile(_NAME + r"(?::" + _NAME + r")*")  # "Label" or "A:B"
    _REL_TYPE_RE = re.compile(_NAME)

    def __init__(self, uri, user, password):
        """Create the driver for the database at ``uri`` using basic auth."""
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never swallow the exception

    def close(self):
        """Close the underlying driver."""
        self.driver.close()

    @staticmethod
    def _checked(pattern, value, what):
        """Return ``value`` if it fully matches ``pattern``, else raise ValueError."""
        if not isinstance(value, str) or pattern.fullmatch(value) is None:
            raise ValueError(f"Invalid Cypher {what}: {value!r}")
        return value

    def create_entity(self, entity_type, name):
        """Ensure a node labelled ``entity_type`` with the given ``name`` exists.

        Raises:
            ValueError: If ``entity_type`` is not a plain identifier (or a
                colon-separated list of them).
        """
        label = self._checked(self._LABEL_RE, entity_type, "label")
        with self.driver.session() as session:
            session.run("MERGE (a:" + label + " {name: $name})", name=name)

    def create_relationship(self, entity1_type, entity1_name, relation, entity2_type, entity2_name):
        """Ensure a ``relation`` edge exists from entity 1 to entity 2.

        Both nodes are looked up by label and name; nothing is written when
        either node is missing.

        Raises:
            ValueError: If a label or the relationship type is not a plain
                identifier.
        """
        label1 = self._checked(self._LABEL_RE, entity1_type, "label")
        label2 = self._checked(self._LABEL_RE, entity2_type, "label")
        rel_type = self._checked(self._REL_TYPE_RE, relation, "relationship type")
        with self.driver.session() as session:
            session.run("""
                MATCH (a:""" + label1 + """ {name: $entity1_name})
                MATCH (b:""" + label2 + """ {name: $entity2_name})
                MERGE (a)-[r:""" + rel_type + """]->(b)
                """, entity1_name=entity1_name, entity2_name=entity2_name)

    def add_qa_pair(self, question, answer):
        """Store a question, its answer, their embeddings and the question's keywords.

        Everything is written by a single query. The keyword edges are attached
        to the ``Question`` node created by this call; the node is not looked up
        again by its text, which would also match any other ``Question`` node
        holding the same text.
        """
        entities = [
            {"name": entity, "type": entity_type}
            for entity, entity_type in extract_entities(question)
        ]
        q_embedding = get_bert_embedding(question).tolist()
        a_embedding = get_bert_embedding(answer).tolist()

        with self.driver.session() as session:
            # UNWIND over an empty list yields no rows, so a question without
            # keywords is still created together with its answer.
            session.run("""
                CREATE (q:Question {text: $question, embedding: $q_embedding})
                CREATE (a:Answer {text: $answer, embedding: $a_embedding})
                CREATE (q)-[:HAS_ANSWER]->(a)
                WITH q
                UNWIND $entities AS entity
                MERGE (e:Entity {name: entity.name, type: entity.type})
                MERGE (q)-[:CONTAINS]->(e)
            """, question=question, answer=answer, q_embedding=q_embedding,
                a_embedding=a_embedding, entities=entities)

    def get_answer(self, question):
        """Return the stored answer whose question is most similar to ``question``.

        Similarity is the cosine similarity between the embedding of
        ``question`` and the stored question embeddings.

        Returns:
            str: The text of the best-matching answer, or ``"No answer found."``
            when the graph holds no question with an embedding and an answer.
        """
        q_embedding = get_bert_embedding(question)

        with self.driver.session() as session:
            result = session.run("""
                MATCH (q:Question)-[:HAS_ANSWER]->(a:Answer)
                WHERE q.embedding IS NOT NULL
                RETURN q.embedding AS q_embedding, a.text AS answer
            """)
            # Read every record while the session is still open.
            rows = [(record["q_embedding"], record["answer"]) for record in result]

        if not rows:
            return "No answer found."

        embeddings = [embedding for embedding, _ in rows]
        answers = [stored_answer for _, stored_answer in rows]
        similarities = cosine_similarity([q_embedding], embeddings)[0]
        return answers[int(np.argmax(similarities))]

def calculate_bleu(reference, candidate):
    """Compute a sentence-level BLEU score of ``candidate`` against ``reference``.

    Both texts are tokenized with ``simple_tokenize``. N-gram orders from 1 up
    to ``min(4, len(reference tokens), len(candidate tokens))`` are used, so
    short texts are scored on the orders they can actually form.

    For each order the modified precision is computed: every candidate n-gram
    is credited at most as many times as it occurs in the reference, so
    repeating a matching word does not inflate the score. The score is the
    geometric mean of the precisions multiplied by the brevity penalty
    ``exp(1 - r/c)`` when the candidate (length ``c``) is shorter than the
    reference (length ``r``), and 1 otherwise.

    Args:
        reference: The ground-truth text.
        candidate: The generated text.

    Returns:
        float: A score in ``[0, 1]``. ``0.0`` is returned when either text has
        no tokens or when any n-gram order has no match.
    """
    from collections import Counter

    ref_tokens = simple_tokenize(reference)
    cand_tokens = simple_tokenize(candidate)

    # Without tokens on either side there are no n-gram orders to average over.
    if not ref_tokens or not cand_tokens:
        return 0.0

    max_n = min(4, len(ref_tokens), len(cand_tokens))
    log_precision_sum = 0.0
    for n in range(1, max_n + 1):
        ref_counts = Counter(zip(*[ref_tokens[i:] for i in range(n)]))
        cand_counts = Counter(zip(*[cand_tokens[i:] for i in range(n)]))
        # Clip each candidate n-gram count by its count in the reference.
        clipped = sum(min(count, ref_counts[ngram]) for ngram, count in cand_counts.items())
        if clipped == 0:
            # A zero precision makes the geometric mean zero.
            return 0.0
        log_precision_sum += math.log(clipped / sum(cand_counts.values()))

    ref_len = len(ref_tokens)
    cand_len = len(cand_tokens)
    bp = 1.0 if cand_len >= ref_len else math.exp(1 - ref_len / cand_len)

    return bp * math.exp(log_precision_sum / max_n)

def answer_question(graph, question, ground_truth):
    """Look up an answer for ``question`` and score it against ``ground_truth``.

    Args:
        graph: Object exposing ``get_answer(question)``.
        question: The question text to answer.
        ground_truth: The reference answer used for scoring.

    Returns:
        tuple: ``(answer, bleu_score)`` where ``answer`` is what
        ``graph.get_answer`` returned and ``bleu_score`` is
        ``calculate_bleu(ground_truth, answer)``.
    """
    predicted = graph.get_answer(question)
    return predicted, calculate_bleu(ground_truth, predicted)

def main(max_questions=100):
    """Build the knowledge graph from the first rows of the dataset and run one query.

    Steps: load the dataset, open the Neo4j connection, insert the first
    ``max_questions`` question/answer pairs, then answer the question in the
    second row and print its BLEU score against the stored ground truth.

    The connection settings can be overridden with the ``NEO4J_URI``,
    ``NEO4J_USER`` and ``NEO4J_PASSWORD`` environment variables; the previous
    hard-coded values remain as fallbacks so existing setups keep working.

    Note that the sample question is one of the rows inserted just before
    (whenever ``max_questions >= 2``), so the printed score is a smoke test of
    the pipeline rather than a held-out evaluation.

    Args:
        max_questions: Number of leading rows to insert (default 100).
    """
    import os

    df = load_and_preprocess_data()
    print("Dataset loaded and preprocessed.")

    # Prefer supplying credentials through the environment instead of relying
    # on the fallbacks below.
    graph = MedicalKnowledgeGraph(
        os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
        os.environ.get("NEO4J_USER", "neo4j"),
        os.environ.get("NEO4J_PASSWORD", "123456789"),
    )
    print("Connected to Neo4j database.")

    try:
        # head() bounds the loop by position, independent of the index labels.
        for count, (_, row) in enumerate(df.head(max_questions).iterrows(), start=1):
            graph.add_qa_pair(row['question'], row['answer'])
            print(f"Processed question {count}")

        sample_question = df.iloc[1]['question']
        ground_truth = df.iloc[1]['answer']
        generated_answer, bleu_score = answer_question(graph, sample_question, ground_truth)
        print(f"Question: {sample_question}")
        print(f"Generated Answer: {generated_answer}")
        print(f"Ground Truth: {ground_truth}")
        print(f"BLEU Score: {bleu_score}")
    finally:
        # Release the driver even when loading or querying fails part-way.
        graph.close()

    print("Knowledge graph construction and evaluation completed.")

# Run the pipeline only when this code is executed as the main module, not when
# it is imported from elsewhere.
if __name__ == "__main__":
    main()


Dataset loaded and preprocessed.
Connected to Neo4j database.
Processed question 1
Processed question 2
Processed question 3
Processed question 4
Processed question 5
Processed question 6
Processed question 7
Processed question 8
Processed question 9
Processed question 10
Processed question 11
Processed question 12
Processed question 13
Processed question 14
Processed question 15
Processed question 16
Processed question 17
Processed question 18
Processed question 19
Processed question 20
Processed question 21
Processed question 22
Processed question 23
Processed question 24
Processed question 25
Processed question 26
Processed question 27
Processed question 28
Processed question 29
Processed question 30
Processed question 31
Processed question 32
Processed question 33
Processed question 34
Processed question 35
Processed question 36
Processed question 37
Processed question 38
Processed question 39
Processed question 40
Processed question 41
Processed question 42
Processed question 43
P