In [None]:
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F

# These classes are assumed to be defined in your environment
from ChromaVDB.chroma import ChromaFramework
from DeepGraphDB import DeepGraphDB

# Open the knowledge graph from its serialized file and attach to the persisted
# Chroma vector store.
gdb = DeepGraphDB()
gdb.load_graph("/home/cc/PHD/dglframework/DeepKG/DeepGraphDB/graphs/primekg.bin")
vdb = ChromaFramework(persist_directory="./ChromaVDB/chroma_db")

# Fetch every stored record once, then split the fields into four parallel lists
# (same order as `records`), one list per record key.
records = vdb.list_records()
names, entities, embs, ids = [
    [rec[field] for rec in records]
    for field in ('name', 'entity', 'embeddings', 'id')
]

# Load patient data.
# The commented-out lines below are a disabled alternative input: a BRCA patient/gene
# matrix CSV filtered on resection site and primary diagnosis. They are not executed.
# data = pd.read_csv("/home/cc/PHD/dglframework/cptac/patient_gene_matrix_BRCA.csv", low_memory=False)
# # data = data[(data['site_of_resection_or_biopsy'] == 'Breast, NOS') & (data['primary_diagnosis'].isin(['Infiltrating duct carcinoma, NOS','Lobular carcinoma, NOS']))]
# data = data[(data['site_of_resection_or_biopsy'] == 'Breast, NOS') & (data['primary_diagnosis'] == 'Infiltrating duct carcinoma, NOS' )]

# Active input: the Excel workbook is read with pandas defaults; no row filtering is
# applied to it here.
data = pd.read_excel('data/2025_03_29.xlsx') # (Diffuse Large B-cell Lymphoma)

In [None]:
# Collect every canonical edge type that touches a 'geneprotein' node on either end.
# A canonical edge type is a (source node type, relation name, destination node type)
# triple, so the destination type lives at index 2; index 1 is the relation name and
# must not be compared against a node type.
etypes = [
    tup for tup in gdb.graph.canonical_etypes
    if tup[0] == 'geneprotein' or tup[2] == 'geneprotein'
]
if not etypes:
    # Fail with a clear message instead of a bare IndexError on etypes[0] below.
    raise ValueError("No canonical edge type involves 'geneprotein' nodes.")

# Inspect the first matching edge type: its endpoint node types and the
# source/destination node IDs of all of its edges.
src_type, _, dst_type = etypes[0]
src, dst = gdb.graph.edges(etype=etypes[0])

In [None]:
# Display the maximum of `src`, the first value returned by gdb.graph.edges(...) for
# etypes[0] (presumably the highest source-node ID for that edge type — TODO confirm).
src.max()

In [None]:
# Marker that identifies the mutation columns of `data` by name.
# Alternative marker: dfstring = '_mutated'
dfstring = '_plasma_MUT'

# Column-wise totals over every column whose name contains the marker,
# ordered from the largest total to the smallest.
mutation_counts = (
    data.filter(like=dfstring)
    .sum()
    .sort_values(ascending=False)
)

# Keep only the columns with a total of at least 1, then average those totals.
genes_with_mutations = mutation_counts[mutation_counts >= 1]
mean_mutation_count = genes_with_mutations.mean()

print("--- Genes With at Least 1 Mutation ---")
print(len(genes_with_mutations))

print("\n--- Mean Number of Mutations ---")
print(mean_mutation_count)

In [None]:
# Display the per-column mutation totals that are >= 1, in descending order.
genes_with_mutations

In [None]:
# Map each matching record name to its position in `names` (case-insensitive
# substring match). If a name occurs more than once, the last position wins.
# Alternative filter: "breast carcinoma" in name.lower() and "duct" in name.lower()
diseases = dict(
    (label, position)
    for position, label in enumerate(names)
    if "diffuse large b-cell lymphoma" in label.lower()
)

In [None]:
# Maximum number of top-ranked gene/proteins to report for each disease.
TOP_K = 10

# 1. Keep only the records whose entity type is 'geneprotein', as (name, embedding) pairs.
print("Filtering for gene/protein embeddings...")
gene_protein_data = [
    (name, emb) for name, entity, emb in zip(names, entities, embs) if entity == 'geneprotein'
]
if not gene_protein_data:
    # zip(*[]) below would otherwise fail with an opaque unpacking error.
    raise ValueError("No 'geneprotein' entries found in `entities`; nothing to compare against.")

# Unzip the pairs into two parallel tuples: names and embeddings.
gene_protein_names, gene_protein_embs_list = zip(*gene_protein_data)

# Build one float32 tensor holding all gene/protein embeddings so each disease is
# compared against all of them in a single call (assumes all embeddings share one length).
gene_protein_embs_tensor = torch.tensor(gene_protein_embs_list, dtype=torch.float32)
print(f"Found {len(gene_protein_names)} gene/protein entities.")

# torch.topk raises when k exceeds the number of candidates, so never ask for more
# results than there are gene/protein embeddings.
top_k = min(TOP_K, len(gene_protein_names))

# 2. For every target disease, rank the gene/proteins by cosine similarity.
print("\n--- Finding Most Similar Gene/Proteins ---")
for disease_name, disease_index in diseases.items():
    print(f"\nDisease: {disease_name}")

    # Embedding of the current disease, looked up by its position in `embs`.
    disease_emb = torch.tensor(embs[disease_index], dtype=torch.float32)

    # unsqueeze(0) turns the disease vector into shape [1, D] so it broadcasts
    # against the [N, D] gene/protein matrix, giving one similarity per gene/protein.
    similarities = F.cosine_similarity(disease_emb.unsqueeze(0), gene_protein_embs_tensor)

    # Highest similarities first (values and their row indices).
    top_results = torch.topk(similarities, k=top_k)

    # 3. Report each hit with its mutation total; the total falls back to 0 when
    #    `genes_with_mutations` has no "<gene name><dfstring>" entry.
    for i in range(top_k):
        score = top_results.values[i].item()
        gene_index = top_results.indices[i].item()
        gene_name = gene_protein_names[gene_index]
        print(f"  {i+1}. {gene_name} (Similarity: {score:.4f}) - mutation count: {genes_with_mutations.get(gene_name+dfstring, 0)}")