In [None]:
import pathlib, os
# Environment configuration: set before the Hugging Face libraries below are imported.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Hugging Face cache root, expanded to an absolute path under ~/scratch-llm.
os.environ['HF_HOME'] = str(pathlib.Path("~/scratch-llm/storage/cache/huggingface/").expanduser().absolute()) # '/scratch-llm/storage/cache/'
# os.environ["TRANSFORMERS_CACHE"] = "~/scratch-llm/storage/models/"

import torch, pickle
import numpy as np

from transformers import AutoTokenizer
from nebulagraph_lite import nebulagraph_let as ng_let
from llama_index.graph_stores.nebula import NebulaPropertyGraphStore

from llama_index.core import Settings
from llama_index.core.schema import TextNode
# NOTE(review): PromptTemplate, TreeSummarize, BaseModel, Field and PydanticOutputParser are not
# used in the cells visible here -- drop them if no other cell needs them.
from llama_index.core.prompts import PromptTemplate
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core.vector_stores.simple import SimpleVectorStoreData, SimpleVectorStore, VectorStoreQuery

from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

from typing import List
# dot / norm: cosine similarity between disease and symptom embeddings in SymptomsMode.
from numpy import dot
from numpy.linalg import norm
from pydantic import BaseModel, Field
from llama_index.core.output_parsers import PydanticOutputParser

# NebulaGraph connection

In [None]:
# Start a local NebulaGraph-Lite instance (stopped again with `n.stop()` in the last cell).
# The commented-out udocker commands are the manual pull/create/setup steps for the
# metad, graphd and storaged images; they are kept for reference and are not executed.
# !udocker pull vesoft/nebula-metad:v3
# !udocker create --name=nebula-metad vesoft/nebula-metad:v3
# !udocker setup --execmode=F1 nebula-metad
# !udocker pull vesoft/nebula-graphd:v3
# !udocker create --name=nebula-graphd vesoft/nebula-graphd:v3
# !udocker setup --execmode=F1 nebula-graphd
# !udocker pull vesoft/nebula-storaged:v3
# !udocker create --name=nebula-storaged vesoft/nebula-storaged:v3
# !udocker setup --execmode=F1 nebula-storaged


n = ng_let(in_container=True)
n.start() # Takes around 5 mins

In [None]:
# Load the nGQL Jupyter magic and connect it to the local NebulaGraph instance on 127.0.0.1:9669.
# NOTE(review): credentials are hardcoded; presumably the local instance defaults -- confirm,
# and never commit real credentials this way.
%reload_ext ngql
%ngql --address 127.0.0.1 --port 9669 --user root --password nebula

# Vector + Graph store

## SimpleVectorStore

In [None]:
# Load the pre-embedded TextNodes into all_nodes_embedded (file name indicates
# all-mpnet-base-v2 embeddings).
# NOTE(review): pickle.load runs arbitrary code from the file -- only load a trusted,
# self-produced pickle.
with open(os.path.expanduser('~/scratch-llm/storage/nodes/all_nodes_all-mpnet-base-v2.pkl'), 'rb') as f:
    all_nodes_embedded: List[TextNode] = pickle.load(f)

# The three node-id-keyed lookups that SimpleVectorStoreData is built from.
embedding_dict = dict((tn.id_, tn.get_embedding()) for tn in all_nodes_embedded)
# Nodes without a source document are recorded under the literal string "None".
text_id_to_ref_doc_id = dict((tn.id_, tn.ref_doc_id or "None") for tn in all_nodes_embedded)
# Metadata is what the node_type filter used by SymptomsMode is evaluated against.
metadata_dict = dict((tn.id_, tn.metadata) for tn in all_nodes_embedded)

store_data = SimpleVectorStoreData(
    embedding_dict=embedding_dict,
    text_id_to_ref_doc_id=text_id_to_ref_doc_id,
    metadata_dict=metadata_dict,
)
vector_store = SimpleVectorStore(data=store_data, stores_text=True)
del store_data

## NebulaPropertyGraphStore

In [None]:
# Property-graph store over the "PrimeKG" space of the local NebulaGraph instance.
# NOTE(review): credentials are hardcoded (same as the %ngql connection); acceptable for a
# throwaway local instance, but read them from the environment for anything shared.
# props_schema declares the node/edge properties (all STRING); SymptomsMode later reads
# them as t.Props__.<name>, e.g. t.Props__.node_name.
graph_store = NebulaPropertyGraphStore(
    space = "PrimeKG",
    username = "root",
    password = "nebula",
    url = "nebula://localhost:9669",
    props_schema= """`node_index` STRING, `node_type` STRING, `node_id` STRING, `node_name` STRING, 
        `node_source` STRING, `mondo_id` STRING, `mondo_name` STRING, `group_id_bert` STRING, 
        `group_name_bert` STRING, `orphanet_prevalence` STRING, `display_relation` STRING """,
)

# LLM

## Llama-3.2-3B-Instruct

In [None]:
# Hugging Face checkpoint shared by the tokenizer and the generator.
LLM_MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
# Generate on the GPU when available. fp16 is only requested there: half precision on CPU
# is slow and not every op supports it, so CPU runs fall back to fp32.
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LLM_DTYPE = torch.float16 if LLM_DEVICE == "cuda" else torch.float32

# Left padding keeps every prompt flush against its generated continuation (decoder-only
# model). Tokenizers have no device placement, so no device_map is passed here.
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME, padding_side="left")
if tokenizer.pad_token_id is None:  # no <pad> token previously defined, only an EOS token
    # pad_token_id is derived from pad_token, so setting the token alone is sufficient.
    tokenizer.pad_token = "<|end_of_text|>"

# Stop on the tokenizer's EOS and on <|eot_id|>, which terminates a Llama 3 chat turn.
# dict.fromkeys de-duplicates (they may be the same id) while preserving order; ids that
# fail to resolve (None) are dropped.
LLM_STOPPING_IDS = [
    token_id
    for token_id in dict.fromkeys(
        [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
    )
    if token_id is not None
]

llm = HuggingFaceLLM(
    model_name=LLM_MODEL_NAME,
    context_window=8192,
    # NOTE(review): 3048 may be a typo for 2048 or 3072 -- confirm the intended budget.
    max_new_tokens=3048,
    generate_kwargs={
        "temperature": 0.10,
        "do_sample": True,
        "pad_token_id": tokenizer.pad_token_id,
        "top_k": 10,
        "top_p": 0.9,
        # "repetition_penalty": 0.9,  # Added to reduce repetition
        # "no_repeat_ngram_size": 3,  # Prevents repetition of n-grams
    },
    model_kwargs={
        "torch_dtype": LLM_DTYPE,
    },
    tokenizer=tokenizer,
    device_map=LLM_DEVICE,
    stopping_ids=LLM_STOPPING_IDS,
    tokenizer_kwargs={"max_length": None},
    is_chat_model=True,
)

Settings.llm = llm
Settings.chunk_size = 1024
# Should match the model that produced the pickled node embeddings
# (all_nodes_all-mpnet-base-v2.pkl). Alternatives noted: BAAI/bge-small-en-v1.5, m3.
Settings.embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-mpnet-base-v2")

# SymptomsMode

In [None]:
from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters, FilterOperator

# Restrict vector-store hits to phenotype nodes: metadata node_type == "effect/phenotype".
phenotype_dict = {
    "key": "node_type",
    "value": "effect/phenotype",
    "operator": FilterOperator.EQ
}

# Build the MetadataFilter explicitly rather than handing MetadataFilters a raw dict.
# A dict is coerced by pydantic into one of several union members, and ExactMatchFilter
# also accepts key/value. The explicit type makes the intended EQ filter unambiguous.
phenotype_filter = MetadataFilters(filters=[MetadataFilter(**phenotype_dict)])

In [None]:
class SymptomsMode():
    """Rank candidate diseases for a list of symptom (phenotype) strings.

    Pipeline:
      1. Each symptom is embedded with ``Settings.embed_model`` and matched to its nearest
         node(s) in ``vector_store``. The match is restricted by a metadata filter, by
         default the notebook-level ``phenotype_filter``.
      2. Each matched node is expanded in ``graph_store`` to the nodes linked to it by a
         ``disease-phenotype-positive`` relation.
      3. Diseases are grouped by how many distinct query symptoms reached them. Within a
         group, they are ranked by the mean cosine similarity between the disease's text
         embedding and the query symptoms that did *not* reach it.
    """

    # nGQL: nodes linked to node $ids by a "disease-phenotype-positive" relation
    # (the pattern is undirected). $ids is bound to a single node id.
    _DISEASE_QUERY = """
        MATCH (e:Node__) WHERE id(e) == $ids
        MATCH p=(e)-[r:Relation__{label:"disease-phenotype-positive"}]-(t)
        RETURN DISTINCT id(t), t.Props__.node_name, t.Chunk__.text
        """

    def __init__(self, vector_store: SimpleVectorStore, graph_store: NebulaPropertyGraphStore,
                 filters=None, similarity_top_k: int = 1):
        """
        Args:
            vector_store: store queried for the node(s) nearest each symptom.
            graph_store: property graph used to expand matched nodes to diseases.
            filters: metadata filters for the vector query. ``None`` uses the
                notebook-level ``phenotype_filter`` at query time.
            similarity_top_k: number of vector-store hits kept per symptom.
        """
        self.vector_store = vector_store
        self.graph_store = graph_store
        self.filters = filters
        self.similarity_top_k = similarity_top_k

    def retrieve(self, query: List[str]):
        """Run the symptom -> phenotype -> disease pipeline and print a grouped report.

        Args:
            query: list of symptom strings. Any non-list input returns ``None``.

        Returns:
            ``None`` for non-list input. Otherwise a dict with two keys:
            ``'count_groups'`` maps a matched-symptom count to a list of
            ``(disease_id, data)`` pairs, sorted by ``combined_score`` descending.
            ``'all_diseases'`` maps disease_id to its data dict.
        """
        if not isinstance(query, list):
            return None

        total_symptoms = len(query)
        print(f"Processing {total_symptoms} symptoms: {query}")

        # Embed each symptom once; reused for retrieval and for cross-similarity.
        symptom_embeddings = {s: Settings.embed_model.get_text_embedding(s) for s in query}

        disease_counter = self._collect_diseases(query, symptom_embeddings)
        for data in disease_counter.values():
            self._score_disease(data, query, symptom_embeddings)
        count_groups = self._group_by_count(disease_counter)
        self._print_report(count_groups, total_symptoms)

        return {
            'count_groups': dict(count_groups),
            'all_diseases': disease_counter,
        }

    def _collect_diseases(self, query, symptom_embeddings):
        """Match each symptom in the vector store and gather the diseases linked in the graph."""
        filters = self.filters if self.filters is not None else phenotype_filter
        disease_counter = {}
        for symptom in query:
            result = self.vector_store.query(
                VectorStoreQuery(
                    query_embedding=symptom_embeddings[symptom],
                    similarity_top_k=self.similarity_top_k,
                    filters=filters,
                )
            )
            for node_id in result.ids or []:
                rows = self.graph_store.structured_query(
                    self._DISEASE_QUERY, param_map={"ids": node_id}
                )
                for row in rows:
                    self._record_disease(disease_counter, row, symptom)
        return disease_counter

    @staticmethod
    def _record_disease(disease_counter, row, symptom):
        """Insert a new disease entry from a graph row, or credit an existing one with ``symptom``."""
        disease_id = row['id(t)']
        entry = disease_counter.get(disease_id)
        if entry is None:
            disease_name = row['t.Props__.node_name']
            # A missing or null chunk text contributes nothing (not the string "None").
            node_text = f"{disease_name}: {row.get('t.Chunk__.text') or ''}"
            disease_counter[disease_id] = {
                'index': disease_id,
                'name': disease_name,
                'embedding': Settings.embed_model.get_text_embedding(node_text),
                'count': 1,
                'symptoms': [symptom],
                'cross_similarities': {},
            }
        elif symptom not in entry['symptoms']:
            # Count distinct symptoms only. A symptom that reaches the same disease twice
            # (repeated query entry, or several hits when similarity_top_k > 1) counts once.
            entry['symptoms'].append(symptom)
            entry['count'] += 1

    @staticmethod
    def _score_disease(data, query, symptom_embeddings):
        """Store cosine similarities to the unmatched symptoms, their mean, and the combined score."""
        from numpy import dot
        from numpy.linalg import norm

        disease_vec = data['embedding']
        unmatched = [s for s in query if s not in data['symptoms']]
        similarities = {
            s: dot(disease_vec, symptom_embeddings[s])
               / (norm(disease_vec) * norm(symptom_embeddings[s]))
            for s in unmatched
        }
        data['cross_similarities'] = similarities
        data['avg_cross_similarity'] = (
            sum(similarities.values()) / len(similarities) if similarities else 0
        )
        # Diseases are grouped by count, so within a group this ranks by avg_cross_similarity.
        data['combined_score'] = data['count'] + data['avg_cross_similarity']

    @staticmethod
    def _group_by_count(disease_counter):
        """Bucket diseases by matched-symptom count; each bucket is sorted by combined_score, descending."""
        from collections import defaultdict

        groups = defaultdict(list)
        for disease_id, data in disease_counter.items():
            groups[data['count']].append((disease_id, data))
        return {
            count: sorted(items, key=lambda item: item[1]['combined_score'], reverse=True)
            for count, items in groups.items()
        }

    @staticmethod
    def _print_report(count_groups, total_symptoms):
        """Print the diseases grouped by match count, highest count first."""
        print("\n=== Diseases grouped by symptom match count ===")
        for count in sorted(count_groups, reverse=True):
            print(f"\n--- Group: {count}/{total_symptoms} symptoms matched ---")
            for _, data in count_groups[count]:
                print(f"ID: {data['index']} | Disease: {data['name']} | "
                      f"Avg cross-similarity: {data['avg_cross_similarity']:.5f} | "
                      f"Combined score: {data['combined_score']:.5f}")

                # Per-symptom similarity for the symptoms this disease did not match.
                if data['cross_similarities']:
                    for symptom, score in data['cross_similarities'].items():
                        print(f"    - {symptom}: {score:.5f}")
                    print('\n')

## Chat

In [None]:
# Example query: phenotype terms for Hutchinson-Gilford progeria syndrome.
user = [
    "Absence of subcutaneous fat",
    "Generalized abnormality of skin",
    "Micrognathia",
    "Narrow mouth",
    "Premature skin wrinkling",
]

# Alternative example -- Hereditary breast and ovarian cancer syndrome:
# [
#     "Breast carcinoma",
#     "Neoplasm of the pancreas",
#     "Ovarian neoplasm",
#     "Abnormal fallopian tube morphology",
#     "Prostate cancer",
# ]

# retrieve() prints the grouped ranking and also returns it.
context = SymptomsMode(vector_store, graph_store).retrieve(user)


In [None]:
# Shut down the local NebulaGraph-Lite instance started with n.start().
n.stop()