In [1]:
import pathlib, os, torch, pickle, time, json
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['HF_HOME'] = str(pathlib.Path("~/scratch-llm/storage/cache/huggingface/").expanduser().absolute()) # '/scratch-llm/storage/cache/'
# os.environ["TRANSFORMERS_CACHE"] = "~/scratch-llm/storage/models/"

import numpy as np

from transformers import AutoTokenizer
from nebulagraph_lite import nebulagraph_let as ng_let
from llama_index.graph_stores.nebula import NebulaPropertyGraphStore

from llama_index.core import Settings
from llama_index.core.schema import TextNode
from llama_index.core.prompts import PromptTemplate
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core.vector_stores.simple import SimpleVectorStoreData, SimpleVectorStore, VectorStoreQuery
from llama_index.core.vector_stores.types import MetadataFilters, FilterOperator

from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

from typing import List
from numpy import dot
from numpy.linalg import norm
from pydantic import BaseModel, Field
from llama_index.core.output_parsers import PydanticOutputParser

# NebulaGraph connection

In [None]:
# Launch NebulaGraph through nebulagraph_lite (debug output enabled).
# Startup is slow: roughly 5 minutes.
#
# Manual alternative, kept for reference: pull and set up the NebulaGraph
# containers with udocker, pausing a few seconds between services.
#   !udocker pull vesoft/nebula-metad:v3
#   !udocker create --name=nebula-metad vesoft/nebula-metad:v3
#   !udocker setup --execmode=F1 nebula-metad
#   !udocker pull vesoft/nebula-storaged:v3
#   !udocker create --name=nebula-storaged vesoft/nebula-storaged:v3
#   !udocker setup --execmode=F1 nebula-storaged
#   !udocker pull vesoft/nebula-graphd:v3
#   !udocker create --name=nebula-graphd vesoft/nebula-graphd:v3
#   !udocker setup --execmode=F1 nebula-graphd
n = ng_let(debug=True, in_container=True)
n.start()

In [3]:
# Load the ngql magic and open a connection to the NebulaGraph instance at 127.0.0.1:9669.
# NOTE(review): root/nebula are hardcoded credentials; acceptable for a throwaway local
# instance, but do not commit real credentials this way.
%reload_ext ngql
%ngql --address 127.0.0.1 --port 9669 --user root --password nebula

[1;3;38;2;0;135;107m[OK] Connection Pool Created[0m


Unnamed: 0,Name
0,PrimeKG
1,basketballplayer


# Vector + Graph store

## SimpleVectorStore:

In [4]:
# Load the pre-embedded TextNodes and index them in an in-memory SimpleVectorStore keyed
# by node id (the file name suggests all-mpnet-base-v2 embeddings).
# NOTE(review): pickle.load can execute arbitrary code; only load files you produced yourself.
nodes_path = os.path.expanduser('~/scratch-llm/storage/nodes/all_nodes_all-mpnet-base-v2.pkl')
with open(nodes_path, 'rb') as fh:
    all_nodes_embedded: List[TextNode] = pickle.load(fh)

# A single pass over the nodes fills the three lookup tables the store is built from.
embedding_dict, text_id_to_ref_doc_id, metadata_dict = {}, {}, {}
for node in all_nodes_embedded:
    embedding_dict[node.id_] = node.get_embedding()
    text_id_to_ref_doc_id[node.id_] = node.ref_doc_id or "None"
    metadata_dict[node.id_] = node.metadata

# NOTE(review): stores_text=True although the store data below carries no node text;
# confirm this is intended (only ids/similarities are read from query results here).
vector_store = SimpleVectorStore(
    data=SimpleVectorStoreData(
        embedding_dict=embedding_dict,
        text_id_to_ref_doc_id=text_id_to_ref_doc_id,
        metadata_dict=metadata_dict,
    ),
    stores_text=True,
)

## NebulaPropertyGraphStore

In [None]:
# Property-graph view of the "PrimeKG" space. Connection details default to the same
# local values as the %ngql connection, but can be overridden through NEBULA_USER /
# NEBULA_PASSWORD / NEBULA_URL so that real credentials never need to live in the notebook.
graph_store = NebulaPropertyGraphStore(
    space="PrimeKG",
    username=os.environ.get("NEBULA_USER", "root"),
    password=os.environ.get("NEBULA_PASSWORD", "nebula"),
    url=os.environ.get("NEBULA_URL", "nebula://localhost:9669"),
    props_schema= """`node_index` STRING, `node_type` STRING, `node_id` STRING, `node_name` STRING, 
        `node_source` STRING, `mondo_id` STRING, `mondo_name` STRING, `group_id_bert` STRING, 
        `group_name_bert` STRING, `orphanet_prevalence` STRING, `display_relation` STRING """,
)

# LLM

## Llama-3.2-3B-Instruct

In [None]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct", padding_side="left", device_map="auto")    
if tokenizer.pad_token_id is None: #no <pad> token previously defined, only eos_token
    tokenizer.pad_token = "<|end_of_text|>"
    tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)


llm = HuggingFaceLLM(
    model_name="meta-llama/Llama-3.2-3B-Instruct",
    context_window=8192,
    max_new_tokens=3048,
    generate_kwargs={
        "temperature": 0.10, 
        "do_sample": True,
        "pad_token_id": tokenizer.pad_token_id,
        "top_k": 10, 
        "top_p": 0.9,
        # "repetition_penalty": 0.9,  # Added to reduce repetition
        # "no_repeat_ngram_size": 3,  # Prevents repetition of n-grams
    },
    model_kwargs={
        "torch_dtype": torch.float16,
    },
    tokenizer=tokenizer,
    # device_map="auto",  # Automatically offload layers to CPU if GPU memory is insufficient
    device_map="cuda" if torch.cuda.is_available() else "cpu",
    stopping_ids=[tokenizer.eos_token_id],
    tokenizer_kwargs={"max_length": None},
    is_chat_model=True,
)

Settings.llm = llm
Settings.chunk_size = 1024
Settings.embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-mpnet-base-v2") # BAAI/bge-small-en-v1.5 /  m3 / sentence-transformers/all-mpnet-base-v2

# SymptomsMode

In [None]:
# Metadata filter restricting vector-store lookups to phenotype nodes
# (node_type == "effect/phenotype"); used as the default filter by SymptomsMode.
phenotype_dict = dict(
    key="node_type",
    value="effect/phenotype",
    operator=FilterOperator.EQ,
)
phenotype_filter = MetadataFilters(filters=[phenotype_dict])

In [None]:
class SymptomsMode:
    """Rank candidate diseases for a list of symptoms using PrimeKG.

    For every symptom, the nearest phenotype node(s) are looked up in the vector
    store. The graph store is then queried for the nodes linked to each phenotype
    through a "disease-phenotype-positive" relation. Diseases are ranked by how
    many of the query symptoms reached them.
    """

    def __init__(self, vector_store: SimpleVectorStore, graph_store: NebulaPropertyGraphStore,
                 filters=None, similarity_top_k: int = 1, top_n_counts: int = 2):
        """Store the backends and retrieval settings.

        Args:
            vector_store: store queried with each symptom's embedding.
            graph_store: property-graph store holding the PrimeKG relations.
            filters: metadata filters applied to every vector lookup. ``None`` means
                the module-level ``phenotype_filter``, resolved at query time.
            similarity_top_k: number of phenotype nodes matched per symptom.
            top_n_counts: how many of the highest distinct match counts to keep.
        """
        self.vector_store = vector_store
        self.graph_store = graph_store
        self.filters = filters
        self.similarity_top_k = similarity_top_k
        self.top_n_counts = top_n_counts

    def retrieve(self, query: List[str]):
        """Rank diseases by the number of query symptoms linked to them, and print them.

        Args:
            query: list (or tuple) of free-text symptom descriptions.

        Returns:
            ``None`` if ``query`` is not a list/tuple. Otherwise a dict with these keys:
              - ``top_diseases``: disease id -> {'index', 'name', 'count', 'symptoms'}.
              - ``top_match_counts``: the kept counts, highest first.
              - ``grouped_diseases``: count -> [(id, data), ...], sorted by name.
              - ``total_symptoms``: ``len(query)``.
              - ``top_diseases_list``: disease names, highest count first.
        """
        from collections import defaultdict

        if not isinstance(query, (list, tuple)):
            return None

        total_symptoms = len(query)
        disease_counter = self._count_diseases(query)

        if not disease_counter:
            print("No diseases found matching any symptoms.")
            return {
                'top_diseases': {},
                'top_match_counts': [],
                'grouped_diseases': {},
                'total_symptoms': total_symptoms,
                'top_diseases_list': []
            }

        # Keep every disease whose count is among the N highest *distinct* counts.
        # Ties are all kept, so more than N diseases may survive.
        all_match_counts = sorted({data['count'] for data in disease_counter.values()}, reverse=True)
        top_match_counts = all_match_counts[:self.top_n_counts]
        top_diseases = {
            disease_id: data for disease_id, data in disease_counter.items()
            if data['count'] in top_match_counts
        }

        # Group by match count, alphabetical by disease name within each group.
        grouped_diseases = defaultdict(list)
        for disease_id, data in top_diseases.items():
            grouped_diseases[data['count']].append((disease_id, data))
        for group in grouped_diseases.values():
            group.sort(key=lambda item: item[1]['name'])

        print(f"\n== Diseases with top {self.top_n_counts} symptom match counts ==")
        all_top_diseases_list = []
        for count in sorted(grouped_diseases, reverse=True):
            print(f"\n--- Diseases with {count}/{total_symptoms} symptom matches ---")
            for _, data in grouped_diseases[count]:
                print(f"ID: {data['index']} | Disease: {data['name']} | Matches: {data['count']}/{total_symptoms}")
                all_top_diseases_list.append(data['name'])

        return {
            'top_diseases': top_diseases,
            'top_match_counts': top_match_counts,
            'grouped_diseases': dict(grouped_diseases),
            'total_symptoms': total_symptoms,
            'top_diseases_list': all_top_diseases_list
        }

    def _count_diseases(self, query):
        """Map disease id -> {'index', 'name', 'count', 'symptoms'} over all query symptoms."""
        filters = self.filters if self.filters is not None else phenotype_filter
        disease_counter = {}
        for symptom in query:
            vector_store_query = VectorStoreQuery(
                query_embedding=Settings.embed_model.get_text_embedding(symptom),
                similarity_top_k=self.similarity_top_k,
                filters=filters,
            )
            results = self.vector_store.query(vector_store_query)

            # Gather the diseases reached from all of this symptom's phenotype nodes first,
            # so that one symptom adds at most 1 to a disease's count even when
            # similarity_top_k > 1. The first name seen for a disease is kept.
            diseases_for_symptom = {}
            for node_id in results.ids or []:
                for row in self._related_diseases(node_id):
                    diseases_for_symptom.setdefault(row['id(t)'], row['t.Props__.node_name'])

            for disease_id, disease_name in diseases_for_symptom.items():
                entry = disease_counter.setdefault(
                    disease_id,
                    {'index': disease_id, 'name': disease_name, 'count': 0, 'symptoms': []},
                )
                entry['count'] += 1
                if symptom not in entry['symptoms']:
                    entry['symptoms'].append(symptom)
        return disease_counter

    def _related_diseases(self, node_id):
        """Rows for neighbours of ``node_id`` across "disease-phenotype-positive" relations (presumably diseases)."""
        return self.graph_store.structured_query(
            """
            MATCH (e:Node__) WHERE id(e) == $ids
            MATCH p=(e)-[r:Relation__{label:"disease-phenotype-positive"}]-(t) 
            RETURN DISTINCT id(t), t.Props__.node_name, t.Chunk__.text
            """,
            param_map={"ids": node_id},
        )

## Prompt templates

In [None]:
no_rag_template = """
You are a medical knowledge assistant specializing in rare diseases. Your task is to provide a differential diagnosis for the following list of symptoms.
List of symptoms: {query_str}

CRITICAL INSTRUCTIONS:
1. Use the information from the context and your own knowledge to provide a comprehensive answer.
2. Return maximum the 10 most relevant diseases, ordered by relevance.
3. Use medical terminology to refer to the diseases, without abreviations.
4. Return EXACTLY this JSON format:

Always format your response as a VALID JSON:
    {
        "symptoms": {query_str},
        "differential_diagnosis": [
            "disease1",
            "disease2",
            ... and so on
        ]
    }

    Do NOT use nested objects. Use exactly "disease" and "symptoms" as shown.
"""

rag_template = """
You are a medical knowledge assistant specializing in rare diseases. Your task is to provide a differential diagnosis for the following list of symptoms.
List of symptoms: {query_str}

Use the following candidate diseases to guide your answer: {text_chunks}

CRITICAL INSTRUCTIONS:
1. Use the information from the context and your own knowledge to provide a comprehensive differential diagnosis.
2. Return maximum the 10 most relevant diseases, ordered by relevance.
3. Use medical terminology to refer to the diseases, without abreviations.
4. Return EXACTLY this JSON format:

Always format your response as a VALID JSON:
    {
        "symptoms": {query_str},
        "differential_diagnosis": [
            "disease1",
            "disease2",
            ... and so on
        ]
    }

    Do NOT use nested objects. Use exactly "disease" and "symptoms" as shown.
"""

## Chat

In [None]:
user = ["Proximal muscle weakness",
        "Delayed ability to walk",
        "Poor head control",
        "Talipes",
        "Muscular dystrophy",
        "Loss of ambulation",
        "Tube feeding",
        "Paroxysmal atrial tachycardia",
        "Hypotonia",
        "Ventricular tachycardia",
        "Delayed ability to roll over",
        "Decreased fetal movement",
        "Neck muscle weakness",
        "Muscle fiber atrophy",
        "Axial muscle weakness",
        "Respiratory insufficiency due to muscle weakness",
        "Distal muscle weakness"
]
context = SymptomsMode(vector_store, graph_store).retrieve(user)
prompt_template = PromptTemplate(rag_template)

try:
    prompt = prompt_template.format(
        query_str=", ".join(user), 
        text_chunks=", ".join([chunk for chunk in context['top_diseases_list']]))
    
    # print(f"\nTemplate: {prompt}")
    response = llm.chat([ChatMessage(role="user", content=prompt)])
    response_text = response.message.content if hasattr(response, 'message') else str(response)
    display(Markdown(response_text))
except ValueError as e:
    try: 
        summarizer = TreeSummarize(verbose=True, llm=llm, summary_template=prompt_template)
        response = summarizer.get_response(
            query_str=", ".join(user),
            text_chunks=", ".join([chunk for chunk in context['top_diseases_list']])
        )
        display(Markdown(response))
    except Exception as e:
        print(f"Error occurred while summarizing: {e}")
        display(Markdown(response))


print(f"\n\n == no RAG response ==")
template= PromptTemplate(no_rag_template)
prompt=template.format(query_str=", ".join(user))
response = llm.chat([ChatMessage(role="user", content=prompt)])

response_text = response.message.content if hasattr(response, 'message') else str(response)
display(Markdown(response_text))


Processing 17 symptoms: ['Proximal muscle weakness', 'Delayed ability to walk', 'Poor head control', 'Talipes', 'Muscular dystrophy', 'Loss of ambulation', 'Tube feeding', 'Paroxysmal atrial tachycardia', 'Hypotonia', 'Ventricular tachycardia', 'Delayed ability to roll over', 'Decreased fetal movement', 'Neck muscle weakness', 'Muscle fiber atrophy', 'Axial muscle weakness', 'Respiratory insufficiency due to muscle weakness', 'Distal muscle weakness']

=== Diseases with top 2 symptom match counts ===

--- Diseases with 9/17 symptom matches ---
ID: 27265 | Disease: congenital myasthenic syndrome | Matches: 9/17
ID: 27315 | Disease: limb-girdle muscular dystrophy | Matches: 9/17

--- Diseases with 7/17 symptom matches ---
ID: 30345 | Disease: Bethlem myopathy | Matches: 7/17
ID: 31690 | Disease: congenital muscular dystrophy due to LMNA mutation | Matches: 7/17
ID: 27294 | Disease: nemaline myopathy | Matches: 7/17
1 text chunks after repacking

RESPONSE OK: {"symptoms":"Proximal muscle 

In [21]:
# Load the phenopackets data. JSON text is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's default encoding.
# (Despite its name, `output_file` is the input path; the name is kept for compatibility.)
output_file = os.path.expanduser('~/scratch-llm/storage/phenopackets/phenopacket_data.json')
with open(output_file, 'r', encoding='utf-8') as f:
    phenopackets = json.load(f)